| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "zagryadskov_m_complex_spmm_ccs/all/include/ops_all.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | #include <omp.h> | ||
| 5 | |||
| 6 | #include <algorithm> | ||
| 7 | #include <complex> | ||
| 8 | #include <stdexcept> | ||
| 9 | #include <tuple> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "util/include/util.hpp" | ||
| 13 | #include "zagryadskov_m_complex_spmm_ccs/common/include/common.hpp" | ||
| 14 | |||
| 15 | namespace zagryadskov_m_complex_spmm_ccs { | ||
| 16 | |||
| 17 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | ZagryadskovMComplexSpMMCCSALL::ZagryadskovMComplexSpMMCCSALL(const InType &in) { |
| 18 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 19 | 6 | int world_rank = 0; | |
| 20 | int err_code = 0; | ||
| 21 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | err_code = MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); |
| 22 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (err_code != MPI_SUCCESS) { |
| 23 | ✗ | throw std::runtime_error("MPI_Comm_rank failed"); | |
| 24 | } | ||
| 25 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (world_rank == 0) { |
| 26 | GetInput() = in; | ||
| 27 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | GetOutput() = CCS(); |
| 28 | } else { | ||
| 29 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
6 | GetInput() = std::make_tuple(CCS(), CCS()); |
| 30 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | GetOutput() = CCS(); |
| 31 | } | ||
| 32 | 6 | } | |
| 33 | |||
| 34 | 12 | void ZagryadskovMComplexSpMMCCSALL::SpMMSymbolic(const CCS &a, const CCS &b, std::vector<int> &col_ptr, int jstart, | |
| 35 | int jend) { | ||
| 36 | 12 | std::vector<int> marker(a.m, -1); | |
| 37 | |||
| 38 |
2/2✓ Branch 0 taken 7 times.
✓ Branch 1 taken 12 times.
|
19 | for (int j = jstart; j < jend; ++j) { |
| 39 | int count = 0; | ||
| 40 | |||
| 41 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 7 times.
|
17 | for (int k = b.col_ptr[j]; k < b.col_ptr[j + 1]; ++k) { |
| 42 | 10 | int b_row = b.row_ind[k]; | |
| 43 |
2/2✓ Branch 0 taken 11 times.
✓ Branch 1 taken 10 times.
|
21 | for (int zp = a.col_ptr[b_row]; zp < a.col_ptr[b_row + 1]; ++zp) { |
| 44 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
|
11 | int a_row = a.row_ind[zp]; |
| 45 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
|
11 | if (marker[a_row] != j) { |
| 46 | 9 | marker[a_row] = j; | |
| 47 | 9 | ++count; | |
| 48 | } | ||
| 49 | } | ||
| 50 | } | ||
| 51 | 7 | col_ptr[j + 1] += count; | |
| 52 | } | ||
| 53 | 12 | } | |
| 54 | |||
| 55 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
|
7 | void ZagryadskovMComplexSpMMCCSALL::SpMMKernel(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero, |
| 56 | std::vector<int> &rows, std::vector<std::complex<double>> &acc, | ||
| 57 | std::vector<int> &marker, int j) { | ||
| 58 | rows.clear(); | ||
| 59 | 7 | int write_ptr = c.col_ptr[j]; | |
| 60 | |||
| 61 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 7 times.
|
17 | for (int k = b.col_ptr[j]; k < b.col_ptr[j + 1]; ++k) { |
| 62 | 10 | std::complex<double> tmpval = b.values[k]; | |
| 63 | 10 | int b_row = b.row_ind[k]; | |
| 64 |
2/2✓ Branch 0 taken 11 times.
✓ Branch 1 taken 10 times.
|
21 | for (int zp = a.col_ptr[b_row]; zp < a.col_ptr[b_row + 1]; ++zp) { |
| 65 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
|
11 | int a_row = a.row_ind[zp]; |
| 66 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
|
11 | acc[a_row] += tmpval * a.values[zp]; |
| 67 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
|
11 | if (marker[a_row] != j) { |
| 68 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
|
9 | marker[a_row] = j; |
| 69 | rows.push_back(a_row); | ||
| 70 | } | ||
| 71 | } | ||
| 72 | } | ||
| 73 | |||
| 74 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 7 times.
|
16 | for (int r_idx : rows) { |
| 75 | 9 | c.row_ind[write_ptr] = r_idx; | |
| 76 | 9 | c.values[write_ptr] = acc[r_idx]; | |
| 77 | 9 | ++write_ptr; | |
| 78 | 9 | acc[r_idx] = zero; | |
| 79 | } | ||
| 80 | 7 | } | |
| 81 | |||
| 82 | 12 | void ZagryadskovMComplexSpMMCCSALL::SpMMNumeric(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero, | |
| 83 | int jstart, int jend) { | ||
| 84 | 12 | std::vector<int> marker(a.m, -1); | |
| 85 |
1/4✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
12 | std::vector<std::complex<double>> acc(a.m, zero); |
| 86 | 12 | std::vector<int> rows; | |
| 87 | |||
| 88 |
2/2✓ Branch 0 taken 7 times.
✓ Branch 1 taken 12 times.
|
19 | for (int j = jstart; j < jend; ++j) { |
| 89 |
1/2✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
|
7 | SpMMKernel(a, b, c, zero, rows, acc, marker, j); |
| 90 | } | ||
| 91 | 12 | } | |
| 92 | |||
| 93 | 6 | void ZagryadskovMComplexSpMMCCSALL::SpMM(const CCS &a, const CCS &b, CCS &c) { | |
| 94 | 6 | c.m = a.m; | |
| 95 | 6 | c.n = b.n; | |
| 96 | 6 | const int num_threads = ppc::util::GetNumThreads(); | |
| 97 | |||
| 98 | 6 | std::complex<double> zero(0.0, 0.0); | |
| 99 | 6 | c.col_ptr.assign(c.n + 1, 0); | |
| 100 | |||
| 101 | 6 | #pragma omp parallel default(none) shared(num_threads, a, b, c) num_threads(ppc::util::GetNumThreads()) | |
| 102 | { | ||
| 103 | int tid = omp_get_thread_num(); | ||
| 104 | int jstart = (tid * b.n) / num_threads; | ||
| 105 | int jend = ((tid + 1) * b.n) / num_threads; | ||
| 106 | SpMMSymbolic(a, b, c.col_ptr, jstart, jend); | ||
| 107 | } | ||
| 108 | |||
| 109 |
2/2✓ Branch 0 taken 7 times.
✓ Branch 1 taken 6 times.
|
13 | for (int j = 0; j < c.n; ++j) { |
| 110 | 7 | c.col_ptr[j + 1] += c.col_ptr[j]; | |
| 111 | } | ||
| 112 | 6 | int nnz = c.col_ptr[b.n]; | |
| 113 | 6 | c.row_ind.resize(nnz); | |
| 114 | 6 | c.values.resize(nnz); | |
| 115 | 6 | #pragma omp parallel default(none) shared(num_threads, a, b, c, zero) num_threads(ppc::util::GetNumThreads()) | |
| 116 | { | ||
| 117 | int tid = omp_get_thread_num(); | ||
| 118 | int jstart = (tid * b.n) / num_threads; | ||
| 119 | int jend = ((tid + 1) * b.n) / num_threads; | ||
| 120 | SpMMNumeric(a, b, c, zero, jstart, jend); | ||
| 121 | } | ||
| 122 | 6 | } | |
| 123 | |||
| 124 | 6 | void ZagryadskovMComplexSpMMCCSALL::BcastCCS(CCS &a, int rank) { | |
| 125 | 6 | MPI_Bcast(&a.m, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 126 | 6 | MPI_Bcast(&a.n, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 127 | |||
| 128 | 6 | int nz = 0; | |
| 129 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == 0) { |
| 130 | 3 | nz = static_cast<int>(a.values.size()); | |
| 131 | } | ||
| 132 | 6 | MPI_Bcast(&nz, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 133 | |||
| 134 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank != 0) { |
| 135 | 3 | a.col_ptr.resize(a.n + 1); | |
| 136 | 3 | a.row_ind.resize(nz); | |
| 137 | 3 | a.values.resize(nz); | |
| 138 | } | ||
| 139 | |||
| 140 | 6 | MPI_Bcast(a.col_ptr.data(), a.n + 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 141 | 6 | MPI_Bcast(a.row_ind.data(), nz, MPI_INT, 0, MPI_COMM_WORLD); | |
| 142 | 6 | MPI_Bcast(a.values.data(), nz, MPI_C_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); | |
| 143 | 6 | } | |
| 144 | |||
| 145 | 3 | void ZagryadskovMComplexSpMMCCSALL::SendCCS(const CCS &m, int dest) { | |
| 146 | 3 | MPI_Send(&m.m, 1, MPI_INT, dest, 0, MPI_COMM_WORLD); | |
| 147 | 3 | MPI_Send(&m.n, 1, MPI_INT, dest, 0, MPI_COMM_WORLD); | |
| 148 | 3 | int nz = static_cast<int>(m.values.size()); | |
| 149 | 3 | MPI_Send(&nz, 1, MPI_INT, dest, 0, MPI_COMM_WORLD); | |
| 150 | |||
| 151 | 3 | MPI_Send(m.col_ptr.data(), m.n + 1, MPI_INT, dest, 0, MPI_COMM_WORLD); | |
| 152 | 3 | MPI_Send(m.row_ind.data(), nz, MPI_INT, dest, 0, MPI_COMM_WORLD); | |
| 153 | 3 | MPI_Send(m.values.data(), nz, MPI_C_DOUBLE_COMPLEX, dest, 0, MPI_COMM_WORLD); | |
| 154 | 3 | } | |
| 155 | |||
| 156 | 3 | void ZagryadskovMComplexSpMMCCSALL::RecvCCS(CCS &m, int src) { | |
| 157 | MPI_Status st; | ||
| 158 | 3 | MPI_Recv(&m.m, 1, MPI_INT, src, 0, MPI_COMM_WORLD, &st); | |
| 159 | 3 | MPI_Recv(&m.n, 1, MPI_INT, src, 0, MPI_COMM_WORLD, &st); | |
| 160 | 3 | int nz = 0; | |
| 161 | 3 | MPI_Recv(&nz, 1, MPI_INT, src, 0, MPI_COMM_WORLD, &st); | |
| 162 | |||
| 163 | 3 | m.col_ptr.resize(m.n + 1); | |
| 164 | 3 | m.row_ind.resize(nz); | |
| 165 | 3 | m.values.resize(nz); | |
| 166 | |||
| 167 | 3 | MPI_Recv(m.col_ptr.data(), m.n + 1, MPI_INT, src, 0, MPI_COMM_WORLD, &st); | |
| 168 | 3 | MPI_Recv(m.row_ind.data(), nz, MPI_INT, src, 0, MPI_COMM_WORLD, &st); | |
| 169 | 3 | MPI_Recv(m.values.data(), nz, MPI_C_DOUBLE_COMPLEX, src, 0, MPI_COMM_WORLD, &st); | |
| 170 | 3 | } | |
| 171 | |||
| 172 | 6 | void ZagryadskovMComplexSpMMCCSALL::ScatterB(const CCS &b, CCS &b_local, const std::vector<int> &col_starts, int rank, | |
| 173 | int size) { | ||
| 174 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == 0) { |
| 175 | 3 | CCS tmp; | |
| 176 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | for (int proc = 0; proc < size; ++proc) { |
| 177 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | int jstart = col_starts[proc]; |
| 178 | 6 | int jend = col_starts[proc + 1]; | |
| 179 | |||
| 180 | 6 | tmp.m = b.m; | |
| 181 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | tmp.n = jend - jstart; |
| 182 | tmp.row_ind.clear(); | ||
| 183 | tmp.values.clear(); | ||
| 184 | tmp.col_ptr.clear(); | ||
| 185 | |||
| 186 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | int nnz_start = b.col_ptr[jstart]; |
| 187 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | int nnz_end = b.col_ptr[jend]; |
| 188 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | tmp.row_ind.assign(b.row_ind.begin() + nnz_start, b.row_ind.begin() + nnz_end); |
| 189 | 6 | tmp.values.assign(b.values.begin() + nnz_start, b.values.begin() + nnz_end); | |
| 190 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | tmp.col_ptr.resize(tmp.n + 1); |
| 191 |
2/2✓ Branch 0 taken 13 times.
✓ Branch 1 taken 6 times.
|
19 | for (int j = 0; j <= tmp.n; ++j) { |
| 192 | 13 | tmp.col_ptr[j] = b.col_ptr[jstart + j] - nnz_start; | |
| 193 | } | ||
| 194 | |||
| 195 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (proc == 0) { |
| 196 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | b_local = tmp; |
| 197 | } else { | ||
| 198 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | SendCCS(tmp, proc); |
| 199 | } | ||
| 200 | } | ||
| 201 | 3 | } else { | |
| 202 | 3 | RecvCCS(b_local, 0); | |
| 203 | } | ||
| 204 | 6 | } | |
| 205 | |||
| 206 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | void ZagryadskovMComplexSpMMCCSALL::GatherC(CCS &c, CCS &c_local, int rank, int size) { |
| 207 | MPI_Status st; | ||
| 208 | 6 | int local_nnz = static_cast<int>(c_local.values.size()); | |
| 209 | int total_nnz = 0; | ||
| 210 | 6 | int local_cols = c_local.n; | |
| 211 | int total_cols = 0; | ||
| 212 | 6 | std::vector<int> tmp; | |
| 213 |
1/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
6 | std::vector<int> recvcounts(size); |
| 214 |
2/6✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
6 | std::vector<int> displs(size); |
| 215 | |||
| 216 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Gather(&local_nnz, 1, MPI_INT, recvcounts.data(), 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 217 | |||
| 218 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == 0) { |
| 219 | 3 | c.m = c_local.m; | |
| 220 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | for (int i = 0; i < size; ++i) { |
| 221 | 6 | displs[i] = total_nnz; | |
| 222 | 6 | total_nnz += recvcounts[i]; | |
| 223 | } | ||
| 224 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | c.row_ind.resize(total_nnz); |
| 225 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | c.values.resize(total_nnz); |
| 226 | } | ||
| 227 | |||
| 228 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Gatherv(c_local.row_ind.data(), local_nnz, MPI_INT, c.row_ind.data(), recvcounts.data(), displs.data(), MPI_INT, |
| 229 | 0, MPI_COMM_WORLD); | ||
| 230 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Gatherv(c_local.values.data(), local_nnz, MPI_C_DOUBLE_COMPLEX, c.values.data(), recvcounts.data(), displs.data(), |
| 231 | MPI_C_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); | ||
| 232 | |||
| 233 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Gather(&local_cols, 1, MPI_INT, recvcounts.data(), 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 234 | |||
| 235 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == 0) { |
| 236 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | for (int i = 0; i < size; ++i) { |
| 237 | 6 | displs[i] = total_cols + 1; | |
| 238 | 6 | total_cols += recvcounts[i]; | |
| 239 | 6 | recvcounts[i] += 1; | |
| 240 | } | ||
| 241 | 3 | c.n = total_cols; | |
| 242 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | c.col_ptr.resize(total_cols + 1); |
| 243 | } | ||
| 244 | |||
| 245 | if (rank != 0) { | ||
| 246 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Send(c_local.col_ptr.data(), c_local.n + 1, MPI_INT, 0, 0, MPI_COMM_WORLD); |
| 247 | } | ||
| 248 | |||
| 249 | if (rank == 0) { | ||
| 250 | std::ranges::copy(c_local.col_ptr, c.col_ptr.begin()); | ||
| 251 | |||
| 252 | 3 | int nz_offset = c_local.col_ptr.back(); | |
| 253 | 3 | int col_offset = c_local.n; | |
| 254 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | for (int proc = 1; proc < size; ++proc) { |
| 255 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | tmp.resize(recvcounts[proc]); |
| 256 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Recv(tmp.data(), recvcounts[proc], MPI_INT, proc, 0, MPI_COMM_WORLD, &st); |
| 257 | |||
| 258 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 3 times.
|
7 | for (int j = 1; j < recvcounts[proc]; ++j) { |
| 259 | 4 | c.col_ptr[col_offset + j] = nz_offset + tmp[j]; | |
| 260 | } | ||
| 261 | |||
| 262 | 3 | nz_offset += tmp.back(); | |
| 263 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | col_offset += recvcounts[proc] - 1; |
| 264 | tmp.clear(); | ||
| 265 | } | ||
| 266 | } | ||
| 267 | 6 | } | |
| 268 | |||
| 269 | 6 | bool ZagryadskovMComplexSpMMCCSALL::ValidationImpl() { | |
| 270 | bool res = false; | ||
| 271 | 6 | int world_rank = 0; | |
| 272 | int err_code = 0; | ||
| 273 | 6 | err_code = MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | |
| 274 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (err_code != MPI_SUCCESS) { |
| 275 | ✗ | throw std::runtime_error("MPI_Comm_rank failed"); | |
| 276 | } | ||
| 277 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (world_rank == 0) { |
| 278 | const CCS &a = std::get<0>(GetInput()); | ||
| 279 | const CCS &b = std::get<1>(GetInput()); | ||
| 280 | 3 | res = a.n == b.m; | |
| 281 | } else { | ||
| 282 | res = true; | ||
| 283 | } | ||
| 284 | 6 | return res; | |
| 285 | } | ||
| 286 | |||
| 287 | 6 | bool ZagryadskovMComplexSpMMCCSALL::PreProcessingImpl() { | |
| 288 | 6 | return true; | |
| 289 | } | ||
| 290 | |||
| 291 | 6 | bool ZagryadskovMComplexSpMMCCSALL::RunImpl() { | |
| 292 | 6 | int world_rank = 0; | |
| 293 | 6 | int world_size = 0; | |
| 294 | 6 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | |
| 295 | 6 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 296 | CCS &a = std::get<0>(GetInput()); | ||
| 297 | CCS &b = std::get<1>(GetInput()); | ||
| 298 | CCS &c = GetOutput(); | ||
| 299 | |||
| 300 | 6 | CCS local_b; | |
| 301 | 6 | CCS local_c; | |
| 302 | 6 | std::vector<int> col_starts; | |
| 303 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (world_rank == 0) { |
| 304 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | col_starts.resize(world_size + 1); |
| 305 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 3 times.
|
12 | for (int proc = 0; proc <= world_size; ++proc) { |
| 306 | 9 | col_starts[proc] = (proc * b.n) / world_size; | |
| 307 | } | ||
| 308 | } | ||
| 309 | |||
| 310 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | BcastCCS(a, world_rank); |
| 311 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | ScatterB(b, local_b, col_starts, world_rank, world_size); |
| 312 | |||
| 313 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | ZagryadskovMComplexSpMMCCSALL::SpMM(a, local_b, local_c); |
| 314 | |||
| 315 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | GatherC(c, local_c, world_rank, world_size); |
| 316 | |||
| 317 | 6 | return true; | |
| 318 | 6 | } | |
| 319 | |||
| 320 | 6 | bool ZagryadskovMComplexSpMMCCSALL::PostProcessingImpl() { | |
| 321 | 6 | int world_rank = 0; | |
| 322 | 6 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | |
| 323 | 6 | int m = 0; | |
| 324 | 6 | int n = 0; | |
| 325 | 6 | int nz = 0; | |
| 326 | CCS &c = GetOutput(); | ||
| 327 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (world_rank == 0) { |
| 328 | 3 | m = c.m; | |
| 329 | 3 | n = c.n; | |
| 330 | 3 | nz = static_cast<int>(c.values.size()); | |
| 331 | } | ||
| 332 | 6 | MPI_Bcast(&m, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 333 | 6 | MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 334 | 6 | MPI_Bcast(&nz, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 335 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (world_rank != 0) { |
| 336 | 3 | c.m = m; | |
| 337 | 3 | c.n = n; | |
| 338 | 3 | c.col_ptr.resize(n + 1); | |
| 339 | 3 | c.row_ind.resize(nz); | |
| 340 | 3 | c.values.resize(nz); | |
| 341 | } | ||
| 342 | 6 | MPI_Bcast(c.col_ptr.data(), n + 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 343 | 6 | MPI_Bcast(c.row_ind.data(), nz, MPI_INT, 0, MPI_COMM_WORLD); | |
| 344 | 6 | MPI_Bcast(c.values.data(), nz, MPI_C_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); | |
| 345 | |||
| 346 | 6 | return true; | |
| 347 | } | ||
| 348 | |||
| 349 | } // namespace zagryadskov_m_complex_spmm_ccs | ||
| 350 |