[1134] | 1 | #include "ep_lib.hpp" |
---|
| 2 | #include <mpi.h> |
---|
[1295] | 3 | #include "ep_mpi.hpp" |
---|
[1134] | 4 | |
---|
[1295] | 5 | |
---|
[1134] | 6 | namespace ep_lib |
---|
| 7 | { |
---|
| 8 | |
---|
[1295] | 9 | int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm) |
---|
| 10 | { |
---|
| 11 | if(!comm.is_ep) |
---|
| 12 | { |
---|
| 13 | return ::MPI_Alltoall(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), to_mpi_comm(comm.mpi_comm)); |
---|
| 14 | } |
---|
[1289] | 15 | |
---|
| 16 | |
---|
[1295] | 17 | assert(valid_type(sendtype) && valid_type(recvtype)); |
---|
| 18 | assert(sendcount == recvcount); |
---|
| 19 | |
---|
| 20 | ::MPI_Aint datasize, llb; |
---|
| 21 | ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(sendtype), &llb, &datasize); |
---|
| 22 | |
---|
| 23 | int count = sendcount; |
---|
[1146] | 24 | |
---|
[1295] | 25 | int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; |
---|
| 26 | int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; |
---|
| 27 | int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; |
---|
| 28 | int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; |
---|
| 29 | int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; |
---|
| 30 | int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; |
---|
| 31 | |
---|
| 32 | void* tmp_recvbuf; |
---|
| 33 | if(ep_rank == 0) tmp_recvbuf = new void*[count * ep_size * ep_size * datasize]; |
---|
| 34 | |
---|
| 35 | MPI_Gather(sendbuf, count*ep_size, sendtype, tmp_recvbuf, count*ep_size, recvtype, 0, comm); |
---|
| 36 | |
---|
[1147] | 37 | |
---|
[1295] | 38 | |
---|
| 39 | // reorder tmp_buf |
---|
| 40 | void* tmp_sendbuf; |
---|
| 41 | if(ep_rank == 0) tmp_sendbuf = new void*[count * ep_size * ep_size * datasize]; |
---|
[1134] | 42 | |
---|
[1295] | 43 | if(ep_rank == 0) |
---|
[1146] | 44 | for(int i=0; i<ep_size; i++) |
---|
[1134] | 45 | { |
---|
[1295] | 46 | for(int j=0; j<ep_size; j++) |
---|
| 47 | { |
---|
| 48 | //printf("tmp_recv[%d] = tmp_send[%d]\n", i*ep_size*count + j*count, j*ep_size*count + i*count); |
---|
| 49 | |
---|
| 50 | memcpy(tmp_sendbuf + j*ep_size*count*datasize + i*count*datasize, tmp_recvbuf + i*ep_size*count*datasize + j*count*datasize, count*datasize); |
---|
| 51 | } |
---|
[1134] | 52 | } |
---|
| 53 | |
---|
[1295] | 54 | MPI_Scatter(tmp_sendbuf, ep_size*count, sendtype, recvbuf, ep_size*recvcount, recvtype, 0, comm); |
---|
| 55 | |
---|
| 56 | if(ep_rank == 0) |
---|
| 57 | { |
---|
| 58 | delete[] tmp_recvbuf; |
---|
| 59 | delete[] tmp_sendbuf; |
---|
| 60 | } |
---|
[1339] | 61 | |
---|
| 62 | MPI_Barrier(comm); |
---|
[1295] | 63 | |
---|
[1134] | 64 | return 0; |
---|
| 65 | } |
---|
| 66 | |
---|
| 67 | } |
---|
| 68 | |
---|
| 69 | |
---|