1 | #include "ep_lib.hpp" |
---|
2 | #include <mpi.h> |
---|
3 | #include "ep_mpi.hpp" |
---|
4 | |
---|
5 | |
---|
6 | namespace ep_lib |
---|
7 | { |
---|
8 | |
---|
9 | int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm) |
---|
10 | { |
---|
11 | if(!comm->is_ep) return ::MPI_Alltoall(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), to_mpi_comm(comm->mpi_comm)); |
---|
12 | if(comm->is_intercomm) return MPI_Alltoall_intercomm(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm); |
---|
13 | |
---|
14 | |
---|
15 | assert(valid_type(sendtype) && valid_type(recvtype)); |
---|
16 | assert(sendcount == recvcount); |
---|
17 | |
---|
18 | ::MPI_Aint datasize, llb; |
---|
19 | ::MPI_Type_get_extent(to_mpi_type(sendtype), &llb, &datasize); |
---|
20 | |
---|
21 | int count = sendcount; |
---|
22 | |
---|
23 | int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first; |
---|
24 | int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first; |
---|
25 | int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first; |
---|
26 | int ep_size = comm->ep_comm_ptr->size_rank_info[0].second; |
---|
27 | int num_ep = comm->ep_comm_ptr->size_rank_info[1].second; |
---|
28 | int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second; |
---|
29 | |
---|
30 | void* tmp_recvbuf; |
---|
31 | if(ep_rank == 0) tmp_recvbuf = new void*[count * ep_size * ep_size * datasize]; |
---|
32 | |
---|
33 | MPI_Gather(sendbuf, count*ep_size, sendtype, tmp_recvbuf, count*ep_size, recvtype, 0, comm); |
---|
34 | |
---|
35 | // reorder tmp_buf |
---|
36 | void* tmp_sendbuf; |
---|
37 | if(ep_rank == 0) tmp_sendbuf = new void*[count * ep_size * ep_size * datasize]; |
---|
38 | |
---|
39 | if(ep_rank == 0) |
---|
40 | for(int i=0; i<ep_size; i++) |
---|
41 | { |
---|
42 | for(int j=0; j<ep_size; j++) |
---|
43 | { |
---|
44 | memcpy(tmp_sendbuf + j*ep_size*count*datasize + i*count*datasize, tmp_recvbuf + i*ep_size*count*datasize + j*count*datasize, count*datasize); |
---|
45 | } |
---|
46 | } |
---|
47 | |
---|
48 | MPI_Scatter(tmp_sendbuf, ep_size*count, sendtype, recvbuf, ep_size*recvcount, recvtype, 0, comm); |
---|
49 | |
---|
50 | if(ep_rank == 0) |
---|
51 | { |
---|
52 | delete[] tmp_recvbuf; |
---|
53 | delete[] tmp_sendbuf; |
---|
54 | } |
---|
55 | } |
---|
56 | |
---|
57 | |
---|
58 | int MPI_Alltoall_intercomm(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm) |
---|
59 | { |
---|
60 | printf("MPI_Alltoall_intercomm not yet implemented\n"); |
---|
61 | MPI_Abort(comm, 0); |
---|
62 | } |
---|
63 | } |
---|
64 | |
---|
65 | |
---|