source: XIOS/trunk/extern/src_ep/ep_gatherv.cpp @ 1034

Last change on this file since 1034 was 1034, checked in by yushan, 7 years ago

adding src_ep into extern folder

File size: 16.6 KB
Line 
1/*!
2   \file ep_gather.cpp
3   \since 2 may 2016
4
5   \brief Definitions of MPI collective function: MPI_Gatherv, MPI_Allgatherv
6 */
7
8#include "ep_lib.hpp"
9#include <mpi.h>
10#include "ep_declaration.hpp"
11
12
13using namespace std;
14
15namespace ep_lib
16{
17  int MPI_Gatherv_local(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
18  {
19    if(datatype == MPI_INT)
20    {
21      Debug("datatype is INT\n");
22      return MPI_Gatherv_local_int(sendbuf, count, recvbuf, recvcounts, displs, comm);
23    }
24    else if(datatype == MPI_FLOAT)
25    {
26      Debug("datatype is FLOAT\n");
27      return MPI_Gatherv_local_float(sendbuf, count, recvbuf, recvcounts, displs, comm);
28    }
29    else if(datatype == MPI_DOUBLE)
30    {
31      Debug("datatype is DOUBLE\n");
32      return MPI_Gatherv_local_double(sendbuf, count, recvbuf, recvcounts, displs, comm);
33    }
34    else if(datatype == MPI_LONG)
35    {
36      Debug("datatype is LONG\n");
37      return MPI_Gatherv_local_long(sendbuf, count, recvbuf, recvcounts, displs, comm);
38    }
39    else if(datatype == MPI_UNSIGNED_LONG)
40    {
41      Debug("datatype is uLONG\n");
42      return MPI_Gatherv_local_ulong(sendbuf, count, recvbuf, recvcounts, displs, comm);
43    }
44    else if(datatype == MPI_CHAR)
45    {
46      Debug("datatype is CHAR\n");
47      return MPI_Gatherv_local_char(sendbuf, count, recvbuf, recvcounts, displs, comm);
48    }
49    else
50    {
51      printf("MPI_Gatherv Datatype not supported!\n");
52      exit(0);
53    }
54  }
55
56  int MPI_Gatherv_local_int(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
57  {
58    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
59    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
60
61    int *buffer = comm.my_buffer->buf_int;
62    int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf));
63    int *recv_buf = static_cast<int*>(recvbuf);
64
65    if(my_rank == 0)
66    {
67      assert(count == recvcounts[0]);
68      copy(send_buf, send_buf+count, recv_buf + displs[0]);
69    }
70
71    for(int j=0; j<count; j+=BUFFER_SIZE)
72    {
73      for(int k=1; k<num_ep; k++)
74      {
75        if(my_rank == k)
76        {
77          #pragma omp critical (write_to_buffer)
78          {
79            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
80            #pragma omp flush
81          }
82        }
83
84        MPI_Barrier_local(comm);
85
86        if(my_rank == 0)
87        {
88          #pragma omp flush
89          #pragma omp critical (read_from_buffer)
90          {
91            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+displs[k]);
92          }
93        }
94
95        MPI_Barrier_local(comm);
96      }
97    }
98  }
99
100  int MPI_Gatherv_local_float(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
101  {
102    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
103    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
104
105    float *buffer = comm.my_buffer->buf_float;
106    float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf));
107    float *recv_buf = static_cast<float*>(recvbuf);
108
109    if(my_rank == 0)
110    {
111      assert(count == recvcounts[0]);
112      copy(send_buf, send_buf+count, recv_buf + displs[0]);
113    }
114
115    for(int j=0; j<count; j+=BUFFER_SIZE)
116    {
117      for(int k=1; k<num_ep; k++)
118      {
119        if(my_rank == k)
120        {
121          #pragma omp critical (write_to_buffer)
122          {
123            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
124            #pragma omp flush
125          }
126        }
127
128        MPI_Barrier_local(comm);
129
130        if(my_rank == 0)
131        {
132          #pragma omp flush
133          #pragma omp critical (read_from_buffer)
134          {
135            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+displs[k]);
136          }
137        }
138
139        MPI_Barrier_local(comm);
140      }
141    }
142  }
143
144  int MPI_Gatherv_local_double(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
145  {
146    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
147    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
148
149    double *buffer = comm.my_buffer->buf_double;
150    double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf));
151    double *recv_buf = static_cast<double*>(recvbuf);
152
153    if(my_rank == 0)
154    {
155      assert(count == recvcounts[0]);
156      copy(send_buf, send_buf+count, recv_buf + displs[0]);
157    }
158
159    for(int j=0; j<count; j+=BUFFER_SIZE)
160    {
161      for(int k=1; k<num_ep; k++)
162      {
163        if(my_rank == k)
164        {
165          #pragma omp critical (write_to_buffer)
166          {
167            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
168            #pragma omp flush
169          }
170        }
171
172        MPI_Barrier_local(comm);
173
174        if(my_rank == 0)
175        {
176          #pragma omp flush
177          #pragma omp critical (read_from_buffer)
178          {
179            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+displs[k]);
180          }
181        }
182
183        MPI_Barrier_local(comm);
184      }
185    }
186  }
187
188  int MPI_Gatherv_local_long(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
189  {
190    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
191    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
192
193    long *buffer = comm.my_buffer->buf_long;
194    long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf));
195    long *recv_buf = static_cast<long*>(recvbuf);
196
197    if(my_rank == 0)
198    {
199      assert(count == recvcounts[0]);
200      copy(send_buf, send_buf+count, recv_buf + displs[0]);
201    }
202
203    for(int j=0; j<count; j+=BUFFER_SIZE)
204    {
205      for(int k=1; k<num_ep; k++)
206      {
207        if(my_rank == k)
208        {
209          #pragma omp critical (write_to_buffer)
210          {
211            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
212            #pragma omp flush
213          }
214        }
215
216        MPI_Barrier_local(comm);
217
218        if(my_rank == 0)
219        {
220          #pragma omp flush
221          #pragma omp critical (read_from_buffer)
222          {
223            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+displs[k]);
224          }
225        }
226
227        MPI_Barrier_local(comm);
228      }
229    }
230  }
231
232  int MPI_Gatherv_local_ulong(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
233  {
234    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
235    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
236
237    unsigned long *buffer = comm.my_buffer->buf_ulong;
238    unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf));
239    unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf);
240
241    if(my_rank == 0)
242    {
243      assert(count == recvcounts[0]);
244      copy(send_buf, send_buf+count, recv_buf + displs[0]);
245    }
246
247    for(int j=0; j<count; j+=BUFFER_SIZE)
248    {
249      for(int k=1; k<num_ep; k++)
250      {
251        if(my_rank == k)
252        {
253          #pragma omp critical (write_to_buffer)
254          {
255            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
256            #pragma omp flush
257          }
258        }
259
260        MPI_Barrier_local(comm);
261
262        if(my_rank == 0)
263        {
264          #pragma omp flush
265          #pragma omp critical (read_from_buffer)
266          {
267            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+displs[k]);
268          }
269        }
270
271        MPI_Barrier_local(comm);
272      }
273    }
274  }
275
276  int MPI_Gatherv_local_char(const void *sendbuf, int count, void *recvbuf, const int recvcounts[], const int displs[], MPI_Comm comm)
277  {
278    int my_rank = comm.ep_comm_ptr->size_rank_info[1].first;
279    int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second;
280
281    char *buffer = comm.my_buffer->buf_char;
282    char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf));
283    char *recv_buf = static_cast<char*>(recvbuf);
284
285    if(my_rank == 0)
286    {
287      assert(count == recvcounts[0]);
288      copy(send_buf, send_buf+count, recv_buf + displs[0]);
289    }
290
291    for(int j=0; j<count; j+=BUFFER_SIZE)
292    {
293      for(int k=1; k<num_ep; k++)
294      {
295        if(my_rank == k)
296        {
297          #pragma omp critical (write_to_buffer)
298          {
299            copy(send_buf+j, send_buf + min(BUFFER_SIZE, count-j) , buffer);
300            #pragma omp flush
301          }
302        }
303
304        MPI_Barrier_local(comm);
305
306        if(my_rank == 0)
307        {
308          #pragma omp flush
309          #pragma omp critical (read_from_buffer)
310          {
311            copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j+displs[k]);
312          }
313        }
314
315        MPI_Barrier_local(comm);
316      }
317    }
318  }
319
320
321  int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
322                  MPI_Datatype recvtype, int root, MPI_Comm comm)
323  {
324 
325   
326   
327
328    if(!comm.is_ep && comm.mpi_comm)
329    {
330      #ifdef _serialized
331      #pragma omp critical (_mpi_call)
332      #endif // _serialized
333      ::MPI_Gatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
334                    static_cast< ::MPI_Datatype>(recvtype), root, static_cast< ::MPI_Comm>(comm.mpi_comm));
335      return 0;
336    }
337
338    if(!comm.mpi_comm) return 0;
339
340    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
341
342    MPI_Datatype datatype = sendtype;
343    int count = sendcount;
344
345    int ep_rank, ep_rank_loc, mpi_rank;
346    int ep_size, num_ep, mpi_size;
347
348    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
349    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
350    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
351    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
352    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
353    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
354   
355    MPI_Bcast(const_cast< int* >(recvcounts), ep_size, MPI_INT, root, comm);
356    MPI_Bcast(const_cast< int* >(displs), ep_size, MPI_INT, root, comm);
357
358
359    int root_mpi_rank = comm.rank_map->at(root).second;
360    int root_ep_loc = comm.rank_map->at(root).first;
361
362
363    ::MPI_Aint datasize, lb;
364    #ifdef _serialized
365    #pragma omp critical (_mpi_call)
366    #endif // _serialized
367    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
368
369    void *local_gather_recvbuf;
370
371    if(ep_rank_loc==0)
372    {
373      int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
374      local_gather_recvbuf = new void*[datasize*buffer_size];
375    }
376
377    // local gather to master
378    int local_displs[num_ep];
379    local_displs[0] = 0;
380    for(int i=1; i<num_ep; i++)
381    {
382      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
383    }
384    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm);
385
386    //MPI_Gather
387    if(ep_rank_loc == 0)
388    {
389
390      int gatherv_recvcnt[mpi_size];
391      int gatherv_displs[mpi_size];
392      int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
393
394      //gatherv_recvcnt = new int[mpi_size];
395      //gatherv_displs = new int[mpi_size];
396
397      #ifdef _serialized
398      #pragma omp critical (_mpi_call)
399      #endif // _serialized
400      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
401
402      gatherv_displs[0] = 0;
403      for(int i=1; i<mpi_size; i++)
404      {
405        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
406      }
407
408      #ifdef _serialized
409      #pragma omp critical (_mpi_call)
410      #endif // _serialized
411      ::MPI_Gatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
412                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm));
413
414      //delete[] gatherv_recvcnt;
415      //delete[] gatherv_displs;
416    }
417
418
419    if(root_ep_loc != 0 && mpi_rank == root_mpi_rank) // root is not master, master send to root and root receive from master
420    {
421      innode_memcpy(0, recvbuf, root_ep_loc, recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm);
422    }
423
424
425
426    if(ep_rank_loc==0)
427    {
428      if(datatype == MPI_INT)
429      {
430        delete[] static_cast<int*>(local_gather_recvbuf);
431      }
432      else if(datatype == MPI_FLOAT)
433      {
434        delete[] static_cast<float*>(local_gather_recvbuf);
435      }
436      else if(datatype == MPI_DOUBLE)
437      {
438        delete[] static_cast<double*>(local_gather_recvbuf);
439      }
440      else if(datatype == MPI_LONG)
441      {
442        delete[] static_cast<long*>(local_gather_recvbuf);
443      }
444      else if(datatype == MPI_UNSIGNED_LONG)
445      {
446        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
447      }
448      else // if(datatype == MPI_CHAR)
449      {
450        delete[] static_cast<char*>(local_gather_recvbuf);
451      }
452    }
453    return 0;
454  }
455
456
457
458  int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[],
459                  MPI_Datatype recvtype, MPI_Comm comm)
460  {
461
462    if(!comm.is_ep && comm.mpi_comm)
463    {
464      #ifdef _serialized
465      #pragma omp critical (_mpi_call)
466      #endif // _serialized
467      ::MPI_Allgatherv(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcounts, displs,
468                       static_cast< ::MPI_Datatype>(recvtype), static_cast< ::MPI_Comm>(comm.mpi_comm));
469      return 0;
470    }
471
472    if(!comm.mpi_comm) return 0;
473
474    assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype));
475
476
477    MPI_Datatype datatype = sendtype;
478    int count = sendcount;
479
480    int ep_rank, ep_rank_loc, mpi_rank;
481    int ep_size, num_ep, mpi_size;
482
483    ep_rank = comm.ep_comm_ptr->size_rank_info[0].first;
484    ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first;
485    mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first;
486    ep_size = comm.ep_comm_ptr->size_rank_info[0].second;
487    num_ep = comm.ep_comm_ptr->size_rank_info[1].second;
488    mpi_size = comm.ep_comm_ptr->size_rank_info[2].second;
489   
490
491    assert(accumulate(recvcounts, recvcounts+ep_size-1, 0) == displs[ep_size-1]); // Only for contunuous gather.
492
493
494    ::MPI_Aint datasize, lb;
495    #ifdef _serialized
496    #pragma omp critical (_mpi_call)
497    #endif // _serialized
498    ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize);
499
500    void *local_gather_recvbuf;
501
502    if(ep_rank_loc==0)
503    {
504      int buffer_size = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
505      local_gather_recvbuf = new void*[datasize*buffer_size];
506    }
507
508    // local gather to master
509    int local_displs[num_ep];
510    local_displs[0] = 0;
511    for(int i=1; i<num_ep; i++)
512    {
513      local_displs[i] = displs[ep_rank-ep_rank_loc+i]-displs[ep_rank-ep_rank_loc];
514    }
515    MPI_Gatherv_local(sendbuf, count, datatype, local_gather_recvbuf, recvcounts+ep_rank-ep_rank_loc, local_displs, comm);
516
517    //MPI_Gather
518    if(ep_rank_loc == 0)
519    {
520      int *gatherv_recvcnt;
521      int *gatherv_displs;
522      int gatherv_cnt = accumulate(recvcounts+ep_rank, recvcounts+ep_rank+num_ep, 0);
523
524      gatherv_recvcnt = new int[mpi_size];
525      gatherv_displs = new int[mpi_size];
526
527      #ifdef _serialized
528      #pragma omp critical (_mpi_call)
529      #endif // _serialized
530      ::MPI_Allgather(&gatherv_cnt, 1, MPI_INT_STD, gatherv_recvcnt, 1, MPI_INT_STD, static_cast< ::MPI_Comm>(comm.mpi_comm));
531      gatherv_displs[0] = displs[0];
532      for(int i=1; i<mpi_size; i++)
533      {
534        gatherv_displs[i] = gatherv_recvcnt[i-1] + gatherv_displs[i-1];
535      }
536      #ifdef _serialized
537      #pragma omp critical (_mpi_call)
538      #endif // _serialized
539      ::MPI_Allgatherv(local_gather_recvbuf, gatherv_cnt, static_cast< ::MPI_Datatype>(datatype), recvbuf, gatherv_recvcnt,
540                    gatherv_displs, static_cast< ::MPI_Datatype>(datatype), static_cast< ::MPI_Comm>(comm.mpi_comm));
541
542      delete[] gatherv_recvcnt;
543      delete[] gatherv_displs;
544    }
545
546    MPI_Bcast_local(recvbuf, accumulate(recvcounts, recvcounts+ep_size, 0), datatype, comm);
547
548    if(ep_rank_loc==0)
549    {
550      if(datatype == MPI_INT)
551      {
552        delete[] static_cast<int*>(local_gather_recvbuf);
553      }
554      else if(datatype == MPI_FLOAT)
555      {
556        delete[] static_cast<float*>(local_gather_recvbuf);
557      }
558      else if(datatype == MPI_DOUBLE)
559      {
560        delete[] static_cast<double*>(local_gather_recvbuf);
561      }
562      else if(datatype == MPI_LONG)
563      {
564        delete[] static_cast<long*>(local_gather_recvbuf);
565      }
566      else if(datatype == MPI_UNSIGNED_LONG)
567      {
568        delete[] static_cast<unsigned long*>(local_gather_recvbuf);
569      }
570      else // if(datatype == MPI_CHAR)
571      {
572        delete[] static_cast<char*>(local_gather_recvbuf);
573      }
574    }
575  }
576
577
578}
Note: See TracBrowser for help on using the repository browser.