Changeset 1295 for XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp
- Timestamp:
- 10/06/17 13:56:33 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp
r1289 r1295 9 9 #include <mpi.h> 10 10 #include "ep_declaration.hpp" 11 #include "ep_mpi.hpp" 11 12 12 13 using namespace std; … … 15 16 { 16 17 17 int MPI_Scatter_local 2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm)18 int MPI_Scatter_local(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int local_root, MPI_Comm comm) 18 19 { 19 if(datatype == MPI_INT) 20 { 21 Debug("datatype is INT\n"); 22 return MPI_Scatter_local_int(sendbuf, count, recvbuf, comm); 23 } 24 else if(datatype == MPI_FLOAT) 25 { 26 Debug("datatype is FLOAT\n"); 27 return MPI_Scatter_local_float(sendbuf, count, recvbuf, comm); 28 } 29 else if(datatype == MPI_DOUBLE) 30 { 31 Debug("datatype is DOUBLE\n"); 32 return MPI_Scatter_local_double(sendbuf, count, recvbuf, comm); 33 } 34 else if(datatype == MPI_LONG) 35 { 36 Debug("datatype is LONG\n"); 37 return MPI_Scatter_local_long(sendbuf, count, recvbuf, comm); 38 } 39 else if(datatype == MPI_UNSIGNED_LONG) 40 { 41 Debug("datatype is uLONG\n"); 42 return MPI_Scatter_local_ulong(sendbuf, count, recvbuf, comm); 43 } 44 else if(datatype == MPI_CHAR) 45 { 46 Debug("datatype is CHAR\n"); 47 return MPI_Scatter_local_char(sendbuf, count, recvbuf, comm); 48 } 49 else 50 { 51 printf("MPI_Scatter Datatype not supported!\n"); 52 exit(0); 53 } 54 } 20 assert(valid_type(sendtype) && valid_type(recvtype)); 21 assert(recvcount == sendcount); 55 22 56 int MPI_Scatter_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 57 { 58 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 59 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 23 ::MPI_Aint datasize, lb; 24 ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize); 25 26 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 27 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 60 28 61 29 62 int *buffer = comm.my_buffer->buf_int; 63 int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 64 int *recv_buf = static_cast<int*>(recvbuf); 30 if(ep_rank_loc == local_root) 31 comm.my_buffer->void_buffer[local_root] = const_cast<void*>(sendbuf); 65 32 66 for(int k=0; k<num_ep; k++) 67 { 68 for(int j=0; j<count; j+=BUFFER_SIZE) 69 { 70 if(my_rank == 0) 71 { 72 #pragma omp critical (write_to_buffer) 73 { 74 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 75 #pragma omp flush 76 } 77 } 33 MPI_Barrier_local(comm); 78 34 79 MPI_Barrier_local(comm); 35 #pragma omp critical (_scatter) 36 memcpy(recvbuf, comm.my_buffer->void_buffer[local_root]+datasize*ep_rank_loc*sendcount, datasize * recvcount); 37 80 38 81 if(my_rank == k) 82 { 83 #pragma omp critical (read_from_buffer) 84 { 85 #pragma omp flush 86 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 87 } 88 } 89 MPI_Barrier_local(comm); 90 } 91 } 39 MPI_Barrier_local(comm); 92 40 } 93 94 int MPI_Scatter_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)95 {96 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;97 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;98 99 float *buffer = comm.my_buffer->buf_float;100 float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));101 float *recv_buf = static_cast<float*>(recvbuf);102 103 for(int k=0; k<num_ep; k++)104 {105 for(int j=0; j<count; j+=BUFFER_SIZE)106 {107 if(my_rank == 0)108 {109 #pragma omp critical (write_to_buffer)110 {111 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);112 #pragma omp flush113 }114 }115 116 MPI_Barrier_local(comm);117 118 if(my_rank == k)119 {120 #pragma omp critical (read_from_buffer)121 {122 #pragma omp flush123 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);124 }125 }126 MPI_Barrier_local(comm);127 }128 }129 }130 131 int MPI_Scatter_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)132 {133 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;134 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;135 136 double *buffer = comm.my_buffer->buf_double;137 double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));138 double *recv_buf = static_cast<double*>(recvbuf);139 140 for(int k=0; k<num_ep; k++)141 {142 for(int j=0; j<count; j+=BUFFER_SIZE)143 {144 if(my_rank == 0)145 {146 #pragma omp critical (write_to_buffer)147 {148 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);149 #pragma omp flush150 }151 }152 153 MPI_Barrier_local(comm);154 155 if(my_rank == k)156 {157 #pragma omp critical (read_from_buffer)158 {159 #pragma omp flush160 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);161 }162 }163 MPI_Barrier_local(comm);164 }165 }166 }167 168 int MPI_Scatter_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)169 {170 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;171 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;172 173 long *buffer = comm.my_buffer->buf_long;174 long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));175 long *recv_buf = static_cast<long*>(recvbuf);176 177 for(int k=0; k<num_ep; k++)178 {179 for(int j=0; j<count; j+=BUFFER_SIZE)180 {181 if(my_rank == 0)182 {183 #pragma omp critical (write_to_buffer)184 {185 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);186 #pragma omp flush187 }188 }189 190 MPI_Barrier_local(comm);191 192 if(my_rank == k)193 {194 #pragma omp critical (read_from_buffer)195 {196 #pragma omp flush197 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);198 }199 }200 MPI_Barrier_local(comm);201 }202 }203 }204 205 206 int MPI_Scatter_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)207 {208 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;209 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;210 211 unsigned long *buffer = comm.my_buffer->buf_ulong;212 unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));213 unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);214 215 for(int k=0; k<num_ep; k++)216 {217 for(int j=0; j<count; j+=BUFFER_SIZE)218 {219 if(my_rank == 0)220 {221 #pragma omp critical (write_to_buffer)222 {223 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);224 #pragma omp flush225 }226 }227 228 MPI_Barrier_local(comm);229 230 if(my_rank == k)231 {232 #pragma omp critical (read_from_buffer)233 {234 #pragma omp flush235 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);236 }237 }238 MPI_Barrier_local(comm);239 }240 }241 }242 243 244 int MPI_Scatter_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm)245 {246 int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;247 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second;248 249 char *buffer = comm.my_buffer->buf_char;250 char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));251 char *recv_buf = static_cast<char*>(recvbuf);252 253 for(int k=0; k<num_ep; k++)254 {255 for(int j=0; j<count; j+=BUFFER_SIZE)256 {257 if(my_rank == 0)258 {259 #pragma omp critical (write_to_buffer)260 {261 copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer);262 #pragma omp flush263 }264 }265 266 MPI_Barrier_local(comm);267 268 if(my_rank == k)269 {270 #pragma omp critical (read_from_buffer)271 {272 #pragma omp flush273 copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j);274 }275 }276 MPI_Barrier_local(comm);277 }278 }279 }280 281 282 283 284 41 285 42 int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) … … 287 44 if(!comm.is_ep) 288 45 { 289 ::MPI_Scatter(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 290 root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 291 return 0; 46 return ::MPI_Scatter(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), root, to_mpi_comm(comm.mpi_comm)); 292 47 } 48 49 assert(sendcount == recvcount); 293 50 294 if(!comm.mpi_comm) return 0; 295 296 assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 51 int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 52 int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 53 int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 54 int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 55 int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 56 int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 297 57 298 58 int root_mpi_rank = comm.rank_map->at(root).second; 299 59 int root_ep_loc = comm.rank_map->at(root).first; 300 60 301 int ep_rank, ep_rank_loc, mpi_rank; 302 int ep_size, num_ep, mpi_size; 303 304 ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 305 ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 306 mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 307 ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 308 num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 309 mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 310 61 bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 62 bool is_root = ep_rank == root; 311 63 312 64 MPI_Datatype datatype = sendtype; … … 314 66 315 67 ::MPI_Aint datasize, lb; 68 ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 69 70 void *tmp_sendbuf; 71 if(is_root) tmp_sendbuf = new void*[ep_size * count * datasize]; 316 72 317 ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 73 // reorder tmp_sendbuf 74 std::vector<int>local_ranks(num_ep); 75 std::vector<int>ranks(ep_size); 76 77 if(mpi_rank == root_mpi_rank) MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), root_ep_loc, comm); 78 else MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), 0, comm); 318 79 319 80 320 void *master_sendbuf;321 void *local_recvbuf;81 std::vector<int> recvcounts(mpi_size, 0); 82 std::vector<int> displs(mpi_size, 0); 322 83 323 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) 84 85 if(is_master) 324 86 { 325 if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count*ep_size]; 87 for(int i=0; i<ep_size; i++) 88 { 89 recvcounts[comm.rank_map->at(i).second]++; 90 } 326 91 327 innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count*ep_size, datatype, comm); 92 for(int i=1; i<mpi_size; i++) 93 displs[i] = displs[i-1] + recvcounts[i-1]; 94 95 ::MPI_Gatherv(local_ranks.data(), num_ep, MPI_INT, ranks.data(), recvcounts.data(), displs.data(), MPI_INT, root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 328 96 } 329 97 330 98 331 99 332 if(ep_rank_loc == 0) 100 // if(is_root) printf("\nranks = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", ranks[0], ranks[1], ranks[2], ranks[3], ranks[4], ranks[5], ranks[6], ranks[7], 101 // ranks[8], ranks[9], ranks[10], ranks[11], ranks[12], ranks[13], ranks[14], ranks[15]); 102 103 if(is_root) 104 for(int i=0; i<ep_size; i++) 333 105 { 334 int mpi_sendcnt = count*num_ep; 335 int mpi_scatterv_sendcnt[mpi_size]; 336 int displs[mpi_size]; 106 memcpy(tmp_sendbuf + i*datasize*count, sendbuf + ranks[i]*datasize*count, count*datasize); 107 } 337 108 338 local_recvbuf = new void*[datasize*mpi_sendcnt]; 339 340 ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 341 342 displs[0] = 0; 343 for(int i=1; i<mpi_size; i++) 344 displs[i] = displs[i-1] + mpi_scatterv_sendcnt[i-1]; 109 // if(is_root) printf("\ntmp_sendbuf = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[2], static_cast<int*>(tmp_sendbuf)[4], static_cast<int*>(tmp_sendbuf)[6], 110 // static_cast<int*>(tmp_sendbuf)[8], static_cast<int*>(tmp_sendbuf)[10], static_cast<int*>(tmp_sendbuf)[12], static_cast<int*>(tmp_sendbuf)[14], 111 // static_cast<int*>(tmp_sendbuf)[16], static_cast<int*>(tmp_sendbuf)[18], static_cast<int*>(tmp_sendbuf)[20], static_cast<int*>(tmp_sendbuf)[22], 112 // static_cast<int*>(tmp_sendbuf)[24], static_cast<int*>(tmp_sendbuf)[26], static_cast<int*>(tmp_sendbuf)[28], static_cast<int*>(tmp_sendbuf)[30] ); 345 113 346 114 347 if(root_ep_loc!=0) 348 { 349 ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 350 local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 351 } 352 else 353 { 354 ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 355 local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 356 } 115 // MPI_Scatterv from root to masters 116 void* local_recvbuf; 117 if(is_master) local_recvbuf = new void*[datasize * num_ep * count]; 118 119 120 if(is_master) 121 { 122 int local_sendcount = num_ep * count; 123 ::MPI_Gather(&local_sendcount, 1, to_mpi_type(MPI_INT), recvcounts.data(), 1, to_mpi_type(MPI_INT), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 124 125 if(is_root) for(int i=1; i<mpi_size; i++) displs[i] = displs[i-1] + recvcounts[i-1]; 126 127 ::MPI_Scatterv(tmp_sendbuf, recvcounts.data(), displs.data(), to_mpi_type(sendtype), local_recvbuf, num_ep*count, to_mpi_type(recvtype), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 128 129 // printf("local_recvbuf = %d %d %d %d\n", static_cast<int*>(local_recvbuf)[0], static_cast<int*>(local_recvbuf)[1], static_cast<int*>(local_recvbuf)[2], static_cast<int*>(local_recvbuf)[3]); 130 // static_cast<int*>(local_recvbuf)[4], static_cast<int*>(local_recvbuf)[5], static_cast<int*>(local_recvbuf)[6], static_cast<int*>(local_recvbuf)[7]); 357 131 } 358 132 359 MPI_Scatter_local2(local_recvbuf, count, datatype, recvbuf, comm); 133 if(mpi_rank == root_mpi_rank) MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, root_ep_loc, comm); 134 else MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, 0, comm); 360 135 361 if(ep_rank_loc == 0) 362 { 363 if(datatype == MPI_INT) 364 { 365 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf); 366 delete[] static_cast<int*>(local_recvbuf); 367 } 368 else if(datatype == MPI_FLOAT) 369 { 370 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf); 371 delete[] static_cast<float*>(local_recvbuf); 372 } 373 else if(datatype == MPI_DOUBLE) 374 { 375 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf); 376 delete[] static_cast<double*>(local_recvbuf); 377 } 378 else if(datatype == MPI_LONG) 379 { 380 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf); 381 delete[] static_cast<long*>(local_recvbuf); 382 } 383 else if(datatype == MPI_UNSIGNED_LONG) 384 { 385 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf); 386 delete[] static_cast<unsigned long*>(local_recvbuf); 387 } 388 else //if(datatype == MPI_DOUBLE) 389 { 390 if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf); 391 delete[] static_cast<char*>(local_recvbuf); 392 } 393 } 136 if(is_root) delete[] tmp_sendbuf; 137 if(is_master) delete[] local_recvbuf; 138 } 394 139 395 }396 140 }
Note: See TracChangeset
for help on using the changeset viewer.