source: XIOS/dev/branch_openmp/extern/src_ep_dev/ep_exscan.cpp @ 1784

Last change on this file since 1784 was 1642, checked in by yushan, 6 years ago

dev on ADA. add flag switch _usingEP/_usingMPI

File size: 10.8 KB
RevLine 
[1134]1/*!
2   \file ep_scan.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Exscan
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
[1295]11#include "ep_mpi.hpp"
[1134]12
13using namespace std;
14
15namespace ep_lib
16{
17  template<typename T>
18  T max_op(T a, T b)
19  {
20    return max(a,b);
21  }
22
23  template<typename T>
24  T min_op(T a, T b)
25  {
26    return min(a,b);
27  }
28
[1295]29  template<typename T>
30  void reduce_max(const T * buffer, T* recvbuf, int count)
[1134]31  {
[1295]32    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
[1134]33  }
34
[1295]35  template<typename T>
36  void reduce_min(const T * buffer, T* recvbuf, int count)
37  {
38    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
39  }
[1134]40
[1295]41  template<typename T>
42  void reduce_sum(const T * buffer, T* recvbuf, int count)
43  {
44    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
45  }
[1134]46
47
[1295]48  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
[1134]49  {
[1295]50    valid_op(op);
[1134]51
[1520]52    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
53    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
54    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
[1295]55   
[1134]56
[1295]57    ::MPI_Aint datasize, lb;
58    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
59
60    if(ep_rank_loc == 0 && mpi_rank != 0)
[1289]61    {
[1520]62      comm->my_buffer->void_buffer[0] = recvbuf;
[1295]63    }
64    if(ep_rank_loc == 0 && mpi_rank == 0)
65    {
[1520]66      comm->my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
[1295]67    } 
68     
[1287]69
[1295]70    MPI_Barrier_local(comm);
[1289]71
[1520]72    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
[1289]73
[1295]74    MPI_Barrier_local(comm);
[1289]75
[1520]76    comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
[1295]77   
78    MPI_Barrier_local(comm);
79
[1642]80    if(op == EP_SUM)
[1295]81    {
[1642]82      if(datatype == EP_INT )
[1289]83      {
[1295]84        assert(datasize == sizeof(int));
85        for(int i=0; i<ep_rank_loc; i++)
[1520]86          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
[1295]87      }
88     
[1642]89      else if(datatype == EP_FLOAT )
[1295]90      {
91        assert(datasize == sizeof(float));
92        for(int i=0; i<ep_rank_loc; i++)
[1520]93          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1295]94      }
95     
[1289]96
[1642]97      else if(datatype == EP_DOUBLE )
[1295]98      {
99        assert(datasize == sizeof(double));
100        for(int i=0; i<ep_rank_loc; i++)
[1520]101          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]102      }
[1289]103
[1642]104      else if(datatype == EP_CHAR )
[1295]105      {
106        assert(datasize == sizeof(char));
107        for(int i=0; i<ep_rank_loc; i++)
[1520]108          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1289]109      }
[1134]110
[1642]111      else if(datatype == EP_LONG )
[1134]112      {
[1295]113        assert(datasize == sizeof(long));
114        for(int i=0; i<ep_rank_loc; i++)
[1520]115          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]116      }
[1289]117
[1642]118      else if(datatype == EP_UNSIGNED_LONG )
[1134]119      {
[1295]120        assert(datasize == sizeof(unsigned long));
121        for(int i=0; i<ep_rank_loc; i++)
[1520]122          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1134]123      }
[1540]124     
[1642]125      else if(datatype == EP_LONG_LONG_INT )
[1540]126      {
127        assert(datasize == sizeof(long long int));
128        for(int i=0; i<ep_rank_loc; i++)
129          reduce_sum<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
130      }
[1134]131
[1540]132      else 
133      {
134        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
135        MPI_Abort(comm, 0);
136      }
[1289]137
[1295]138     
139    }
[1289]140
[1642]141    else if(op == EP_MAX)
[1289]142    {
[1642]143      if(datatype == EP_INT )
[1134]144      {
[1295]145        assert(datasize == sizeof(int));
146        for(int i=0; i<ep_rank_loc; i++)
[1520]147          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
[1295]148      }
[1289]149
[1642]150      else if(datatype == EP_FLOAT )
[1295]151      {
152        assert(datasize == sizeof(float));
153        for(int i=0; i<ep_rank_loc; i++)
[1520]154          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1134]155      }
156
[1642]157      else if(datatype == EP_DOUBLE )
[1295]158      {
159        assert(datasize == sizeof(double));
160        for(int i=0; i<ep_rank_loc; i++)
[1520]161          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]162      }
[1289]163
[1642]164      else if(datatype == EP_CHAR )
[1134]165      {
[1295]166        assert(datasize == sizeof(char));
167        for(int i=0; i<ep_rank_loc; i++)
[1520]168          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1134]169      }
170
[1642]171      else if(datatype == EP_LONG )
[1134]172      {
[1295]173        assert(datasize == sizeof(long));
174        for(int i=0; i<ep_rank_loc; i++)
[1520]175          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1134]176      }
177
[1642]178      else if(datatype == EP_UNSIGNED_LONG )
[1134]179      {
[1295]180        assert(datasize == sizeof(unsigned long));
181        for(int i=0; i<ep_rank_loc; i++)
[1520]182          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1134]183      }
[1295]184     
[1642]185      else if(datatype == EP_LONG_LONG_INT )
[1540]186      {
187        assert(datasize == sizeof(long long int));
188        for(int i=0; i<ep_rank_loc; i++)
189          reduce_max<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
190      }
191
192      else 
193      {
194        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
195        MPI_Abort(comm, 0);
196      }
[1289]197    }
[1134]198
[1642]199    else if(op == EP_MIN)
[1134]200    {
[1642]201      if(datatype == EP_INT )
[1289]202      {
[1295]203        assert(datasize == sizeof(int));
204        for(int i=0; i<ep_rank_loc; i++)
[1520]205          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
[1295]206      }
[1134]207
[1642]208      else if(datatype == EP_FLOAT )
[1295]209      {
210        assert(datasize == sizeof(float));
211        for(int i=0; i<ep_rank_loc; i++)
[1520]212          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
[1289]213      }
[1134]214
[1642]215      else if(datatype == EP_DOUBLE )
[1295]216      {
217        assert(datasize == sizeof(double));
218        for(int i=0; i<ep_rank_loc; i++)
[1520]219          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
[1295]220      }
[1134]221
[1642]222      else if(datatype == EP_CHAR )
[1289]223      {
[1295]224        assert(datasize == sizeof(char));
225        for(int i=0; i<ep_rank_loc; i++)
[1520]226          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
[1295]227      }
[1134]228
[1642]229      else if(datatype == EP_LONG )
[1295]230      {
231        assert(datasize == sizeof(long));
232        for(int i=0; i<ep_rank_loc; i++)
[1520]233          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
[1289]234      }
[1134]235
[1642]236      else if(datatype == EP_UNSIGNED_LONG )
[1289]237      {
[1295]238        assert(datasize == sizeof(unsigned long));
239        for(int i=0; i<ep_rank_loc; i++)
[1520]240          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
[1289]241      }
[1134]242
[1642]243      else if(datatype == EP_LONG_LONG_INT )
[1540]244      {
245        assert(datasize == sizeof(long long int));
246        for(int i=0; i<ep_rank_loc; i++)
247          reduce_min<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
248      }
249
250      else 
251      {
252        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
253        MPI_Abort(comm, 0);
254      }
[1295]255    }
[1540]256   
257    else
258    {
259      printf("op type Error in ep_exscan : MPI_MAX, MPI_MIN, MPI_SUM\n");
260      MPI_Abort(comm, 0);
261    }
[1134]262
[1295]263    MPI_Barrier_local(comm);
[1289]264
265  }
[1134]266
267  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
268  {
[1539]269    if(!comm->is_ep) return ::MPI_Exscan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
270    if(comm->is_intercomm) return MPI_Exscan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
[1295]271   
[1540]272    assert(valid_type(datatype));
273    assert(valid_op(op));
[1134]274
[1520]275    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
276    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
277    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
278    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
279    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
280    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
[1134]281
282    ::MPI_Aint datasize, lb;
[1295]283    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
[1134]284   
[1295]285    void* tmp_sendbuf;
286    tmp_sendbuf = new void*[datasize * count];
[1134]287
[1295]288    int my_src = 0;
289    int my_dst = ep_rank;
[1134]290
[1295]291    std::vector<int> my_map(mpi_size, 0);
[1134]292
[1520]293    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
[1134]294
[1295]295    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
296    my_src += ep_rank_loc;
[1134]297
[1295]298     
299    for(int i=0; i<mpi_size; i++)
[1134]300    {
[1295]301      if(my_dst < my_map[i])
302      {
303        my_dst = get_ep_rank(comm, my_dst, i); 
304        break;
305      }
306      else
307        my_dst -= my_map[i];
[1289]308    }
309
[1295]310    if(ep_rank != my_dst) 
[1289]311    {
[1295]312      MPI_Request request[2];
313      MPI_Status status[2];
[1289]314
[1295]315      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
316   
317      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
318   
319      MPI_Waitall(2, request, status);
[1289]320    }
321
[1295]322    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
323   
[1289]324
[1295]325    void* tmp_recvbuf;
326    tmp_recvbuf = new void*[datasize * count];   
[1289]327
[1295]328    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
[1289]329
[1295]330    if(ep_rank_loc == 0)
[1539]331    {
[1520]332      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
[1539]333    }
[1295]334   
335    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
[1134]336
[1289]337
[1295]338    if(ep_rank != my_src) 
[1289]339    {
[1295]340      MPI_Request request[2];
341      MPI_Status status[2];
[1134]342
[1295]343      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
344   
345      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
346   
347      MPI_Waitall(2, request, status);
[1289]348    }
[1134]349
[1295]350    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
[1134]351
[1295]352    delete[] tmp_sendbuf;
353    delete[] tmp_recvbuf;
[1134]354
[1289]355  }
[1134]356
[1539]357
358  int MPI_Exscan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
359  {
360    printf("MPI_Exscan_intercomm not yet implemented\n");
361    MPI_Abort(comm, 0);
362  }
363
[1520]364}
Note: See TracBrowser for help on using the repository browser.