source: XIOS/dev/dev_trunk_omp/extern/src_ep_dev2/ep_exscan.cpp @ 1651

Last change on this file since 1651 was 1651, checked in by yushan, 5 years ago

dev on EP for tracing with itac. Tested on ADA with test_omp

File size: 10.9 KB
Line 
1#ifdef _usingEP
2/*!
3   \file ep_scan.cpp
4   \since 2 may 2016
5
6   \brief Definitions of MPI collective function: MPI_Exscan
7 */
8
9#include "ep_lib.hpp"
10#include <mpi.h>
11#include "ep_declaration.hpp"
12#include "ep_mpi.hpp"
13
14using namespace std;
15
16namespace ep_lib
17{
18  template<typename T>
19  T max_op(T a, T b)
20  {
21    return max(a,b);
22  }
23
24  template<typename T>
25  T min_op(T a, T b)
26  {
27    return min(a,b);
28  }
29
30  template<typename T>
31  void reduce_max(const T * buffer, T* recvbuf, int count)
32  {
33    transform(buffer, buffer+count, recvbuf, recvbuf, max_op<T>);
34  }
35
36  template<typename T>
37  void reduce_min(const T * buffer, T* recvbuf, int count)
38  {
39    transform(buffer, buffer+count, recvbuf, recvbuf, min_op<T>);
40  }
41
42  template<typename T>
43  void reduce_sum(const T * buffer, T* recvbuf, int count)
44  {
45    transform(buffer, buffer+count, recvbuf, recvbuf, std::plus<T>());
46  }
47
48
49  int MPI_Exscan_local(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
50  {
51    valid_op(op);
52
53    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
54    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
55    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
56   
57
58    ::MPI_Aint datasize, lb;
59    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
60
61    if(ep_rank_loc == 0 && mpi_rank != 0)
62    {
63      comm->my_buffer->void_buffer[0] = recvbuf;
64    }
65    if(ep_rank_loc == 0 && mpi_rank == 0)
66    {
67      comm->my_buffer->void_buffer[0] = const_cast<void*>(sendbuf); 
68    } 
69     
70
71    MPI_Barrier_local(comm);
72
73    memcpy(recvbuf, comm->my_buffer->void_buffer[0], datasize*count);
74
75    MPI_Barrier_local(comm);
76
77    comm->my_buffer->void_buffer[ep_rank_loc] = const_cast<void*>(sendbuf); 
78   
79    MPI_Barrier_local(comm);
80
81    if(op == MPI_SUM)
82    {
83      if(datatype == MPI_INT )
84      {
85        assert(datasize == sizeof(int));
86        for(int i=0; i<ep_rank_loc; i++)
87          reduce_sum<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
88      }
89     
90      else if(datatype == MPI_FLOAT )
91      {
92        assert(datasize == sizeof(float));
93        for(int i=0; i<ep_rank_loc; i++)
94          reduce_sum<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
95      }
96     
97
98      else if(datatype == MPI_DOUBLE )
99      {
100        assert(datasize == sizeof(double));
101        for(int i=0; i<ep_rank_loc; i++)
102          reduce_sum<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
103      }
104
105      else if(datatype == MPI_CHAR )
106      {
107        assert(datasize == sizeof(char));
108        for(int i=0; i<ep_rank_loc; i++)
109          reduce_sum<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
110      }
111
112      else if(datatype == MPI_LONG )
113      {
114        assert(datasize == sizeof(long));
115        for(int i=0; i<ep_rank_loc; i++)
116          reduce_sum<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
117      }
118
119      else if(datatype == MPI_UNSIGNED_LONG )
120      {
121        assert(datasize == sizeof(unsigned long));
122        for(int i=0; i<ep_rank_loc; i++)
123          reduce_sum<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
124      }
125     
126      else if(datatype == MPI_LONG_LONG_INT )
127      {
128        assert(datasize == sizeof(long long int));
129        for(int i=0; i<ep_rank_loc; i++)
130          reduce_sum<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
131      }
132
133      else 
134      {
135        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
136        MPI_Abort(comm, 0);
137      }
138
139     
140    }
141
142    else if(op == MPI_MAX)
143    {
144      if(datatype == MPI_INT )
145      {
146        assert(datasize == sizeof(int));
147        for(int i=0; i<ep_rank_loc; i++)
148          reduce_max<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
149      }
150
151      else if(datatype == MPI_FLOAT )
152      {
153        assert(datasize == sizeof(float));
154        for(int i=0; i<ep_rank_loc; i++)
155          reduce_max<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
156      }
157
158      else if(datatype == MPI_DOUBLE )
159      {
160        assert(datasize == sizeof(double));
161        for(int i=0; i<ep_rank_loc; i++)
162          reduce_max<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
163      }
164
165      else if(datatype == MPI_CHAR )
166      {
167        assert(datasize == sizeof(char));
168        for(int i=0; i<ep_rank_loc; i++)
169          reduce_max<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
170      }
171
172      else if(datatype == MPI_LONG )
173      {
174        assert(datasize == sizeof(long));
175        for(int i=0; i<ep_rank_loc; i++)
176          reduce_max<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
177      }
178
179      else if(datatype == MPI_UNSIGNED_LONG )
180      {
181        assert(datasize == sizeof(unsigned long));
182        for(int i=0; i<ep_rank_loc; i++)
183          reduce_max<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
184      }
185     
186      else if(datatype == MPI_LONG_LONG_INT )
187      {
188        assert(datasize == sizeof(long long int));
189        for(int i=0; i<ep_rank_loc; i++)
190          reduce_max<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
191      }
192
193      else 
194      {
195        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
196        MPI_Abort(comm, 0);
197      }
198    }
199
200    else if(op == MPI_MIN)
201    {
202      if(datatype == MPI_INT )
203      {
204        assert(datasize == sizeof(int));
205        for(int i=0; i<ep_rank_loc; i++)
206          reduce_min<int>(static_cast<int*>(comm->my_buffer->void_buffer[i]), static_cast<int*>(recvbuf), count);   
207      }
208
209      else if(datatype == MPI_FLOAT )
210      {
211        assert(datasize == sizeof(float));
212        for(int i=0; i<ep_rank_loc; i++)
213          reduce_min<float>(static_cast<float*>(comm->my_buffer->void_buffer[i]), static_cast<float*>(recvbuf), count);   
214      }
215
216      else if(datatype == MPI_DOUBLE )
217      {
218        assert(datasize == sizeof(double));
219        for(int i=0; i<ep_rank_loc; i++)
220          reduce_min<double>(static_cast<double*>(comm->my_buffer->void_buffer[i]), static_cast<double*>(recvbuf), count);
221      }
222
223      else if(datatype == MPI_CHAR )
224      {
225        assert(datasize == sizeof(char));
226        for(int i=0; i<ep_rank_loc; i++)
227          reduce_min<char>(static_cast<char*>(comm->my_buffer->void_buffer[i]), static_cast<char*>(recvbuf), count);
228      }
229
230      else if(datatype == MPI_LONG )
231      {
232        assert(datasize == sizeof(long));
233        for(int i=0; i<ep_rank_loc; i++)
234          reduce_min<long>(static_cast<long*>(comm->my_buffer->void_buffer[i]), static_cast<long*>(recvbuf), count);
235      }
236
237      else if(datatype == MPI_UNSIGNED_LONG )
238      {
239        assert(datasize == sizeof(unsigned long));
240        for(int i=0; i<ep_rank_loc; i++)
241          reduce_min<unsigned long>(static_cast<unsigned long*>(comm->my_buffer->void_buffer[i]), static_cast<unsigned long*>(recvbuf), count);   
242      }
243
244      else if(datatype == MPI_LONG_LONG_INT )
245      {
246        assert(datasize == sizeof(long long int));
247        for(int i=0; i<ep_rank_loc; i++)
248          reduce_min<long long int>(static_cast<long long int*>(comm->my_buffer->void_buffer[i]), static_cast<long long int*>(recvbuf), count);   
249      }
250
251      else 
252      {
253        printf("datatype Error in ep_exscan : INT, FLOAT, DOUBLE, CHAR, LONG, UNSIGNED_LONG, LONG_LONG_INT\n");
254        MPI_Abort(comm, 0);
255      }
256    }
257   
258    else
259    {
260      printf("op type Error in ep_exscan : MPI_MAX, MPI_MIN, MPI_SUM\n");
261      MPI_Abort(comm, 0);
262    }
263
264    MPI_Barrier_local(comm);
265
266  }
267
268  int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
269  {
270    if(!comm->is_ep) return ::MPI_Exscan(sendbuf, recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
271    if(comm->is_intercomm) return MPI_Exscan_intercomm(sendbuf, recvbuf, count, datatype, op, comm);
272   
273    assert(valid_type(datatype));
274    assert(valid_op(op));
275
276    int ep_rank = comm->ep_comm_ptr->size_rank_info[0].first;
277    int ep_rank_loc = comm->ep_comm_ptr->size_rank_info[1].first;
278    int mpi_rank = comm->ep_comm_ptr->size_rank_info[2].first;
279    int ep_size = comm->ep_comm_ptr->size_rank_info[0].second;
280    int num_ep = comm->ep_comm_ptr->size_rank_info[1].second;
281    int mpi_size = comm->ep_comm_ptr->size_rank_info[2].second;
282
283    ::MPI_Aint datasize, lb;
284    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize);
285   
286    void* tmp_sendbuf;
287    tmp_sendbuf = new void*[datasize * count];
288
289    int my_src = 0;
290    int my_dst = ep_rank;
291
292    std::vector<int> my_map(mpi_size, 0);
293
294    for(int i=0; i<comm->ep_rank_map->size(); i++) my_map[comm->ep_rank_map->at(i).second]++;
295
296    for(int i=0; i<mpi_rank; i++) my_src += my_map[i];
297    my_src += ep_rank_loc;
298
299     
300    for(int i=0; i<mpi_size; i++)
301    {
302      if(my_dst < my_map[i])
303      {
304        my_dst = get_ep_rank(comm, my_dst, i); 
305        break;
306      }
307      else
308        my_dst -= my_map[i];
309    }
310
311    if(ep_rank != my_dst) 
312    {
313      MPI_Request request[2];
314      MPI_Status status[2];
315
316      MPI_Isend(sendbuf,     count, datatype, my_dst, my_dst,  comm, &request[0]);
317   
318      MPI_Irecv(tmp_sendbuf, count, datatype, my_src, ep_rank, comm, &request[1]);
319   
320      MPI_Waitall(2, request, status);
321    }
322
323    else memcpy(tmp_sendbuf, sendbuf, datasize*count);
324   
325
326    void* tmp_recvbuf;
327    tmp_recvbuf = new void*[datasize * count];   
328
329    MPI_Reduce_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, 0, comm);
330
331    if(ep_rank_loc == 0)
332    {
333      ::MPI_Exscan(MPI_IN_PLACE, tmp_recvbuf, count, to_mpi_type(datatype), to_mpi_op(op), to_mpi_comm(comm->mpi_comm));
334    }
335   
336    MPI_Exscan_local(tmp_sendbuf, tmp_recvbuf, count, datatype, op, comm);
337
338
339    if(ep_rank != my_src) 
340    {
341      MPI_Request request[2];
342      MPI_Status status[2];
343
344      MPI_Isend(tmp_recvbuf, count, datatype, my_src, my_src,  comm, &request[0]);
345   
346      MPI_Irecv(recvbuf,     count, datatype, my_dst, ep_rank, comm, &request[1]);
347   
348      MPI_Waitall(2, request, status);
349    }
350
351    else memcpy(recvbuf, tmp_recvbuf, datasize*count);
352
353    delete[] tmp_sendbuf;
354    delete[] tmp_recvbuf;
355
356  }
357
358
359  int MPI_Exscan_intercomm(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
360  {
361    printf("MPI_Exscan_intercomm not yet implemented\n");
362    MPI_Abort(comm, 0);
363  }
364
365}
366#endif
Note: See TracBrowser for help on using the repository browser.