Ignore:
Timestamp:
10/06/17 13:56:33 (7 years ago)
Author:
yushan
Message:

EP update all

File:
1 edited

Legend:

Unmodified
Added
Removed
  • XIOS/dev/branch_openmp/extern/src_ep_dev/ep_scatter.cpp

    r1289 r1295  
    99#include <mpi.h> 
    1010#include "ep_declaration.hpp" 
     11#include "ep_mpi.hpp" 
    1112 
    1213using namespace std; 
     
    1516{ 
    1617 
    17   int MPI_Scatter_local2(const void *sendbuf, int count, MPI_Datatype datatype, void *recvbuf, MPI_Comm comm) 
     18  int MPI_Scatter_local(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int local_root, MPI_Comm comm) 
    1819  { 
    19     if(datatype == MPI_INT) 
    20     { 
    21       Debug("datatype is INT\n"); 
    22       return MPI_Scatter_local_int(sendbuf, count, recvbuf, comm); 
    23     } 
    24     else if(datatype == MPI_FLOAT) 
    25     { 
    26       Debug("datatype is FLOAT\n"); 
    27       return MPI_Scatter_local_float(sendbuf, count, recvbuf, comm); 
    28     } 
    29     else if(datatype == MPI_DOUBLE) 
    30     { 
    31       Debug("datatype is DOUBLE\n"); 
    32       return MPI_Scatter_local_double(sendbuf, count, recvbuf, comm); 
    33     } 
    34     else if(datatype == MPI_LONG) 
    35     { 
    36       Debug("datatype is LONG\n"); 
    37       return MPI_Scatter_local_long(sendbuf, count, recvbuf, comm); 
    38     } 
    39     else if(datatype == MPI_UNSIGNED_LONG) 
    40     { 
    41       Debug("datatype is uLONG\n"); 
    42       return MPI_Scatter_local_ulong(sendbuf, count, recvbuf, comm); 
    43     } 
    44     else if(datatype == MPI_CHAR) 
    45     { 
    46       Debug("datatype is CHAR\n"); 
    47       return MPI_Scatter_local_char(sendbuf, count, recvbuf, comm); 
    48     } 
    49     else 
    50     { 
    51       printf("MPI_Scatter Datatype not supported!\n"); 
    52       exit(0); 
    53     } 
    54   } 
     20    assert(valid_type(sendtype) && valid_type(recvtype)); 
     21    assert(recvcount == sendcount); 
    5522 
    56   int MPI_Scatter_local_int(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
    57   { 
    58     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    59     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
     23    ::MPI_Aint datasize, lb; 
     24    ::MPI_Type_get_extent(to_mpi_type(sendtype), &lb, &datasize); 
     25 
     26    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     27    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    6028 
    6129 
    62     int *buffer = comm.my_buffer->buf_int; 
    63     int *send_buf = static_cast<int*>(const_cast<void*>(sendbuf)); 
    64     int *recv_buf = static_cast<int*>(recvbuf); 
     30    if(ep_rank_loc == local_root) 
     31      comm.my_buffer->void_buffer[local_root] = const_cast<void*>(sendbuf); 
    6532 
    66     for(int k=0; k<num_ep; k++) 
    67     { 
    68       for(int j=0; j<count; j+=BUFFER_SIZE) 
    69       { 
    70         if(my_rank == 0) 
    71         { 
    72           #pragma omp critical (write_to_buffer) 
    73           { 
    74             copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
    75             #pragma omp flush 
    76           } 
    77         } 
     33    MPI_Barrier_local(comm); 
    7834 
    79         MPI_Barrier_local(comm); 
     35    #pragma omp critical (_scatter)       
     36    memcpy(recvbuf, comm.my_buffer->void_buffer[local_root]+datasize*ep_rank_loc*sendcount, datasize * recvcount); 
     37     
    8038 
    81         if(my_rank == k) 
    82         { 
    83           #pragma omp critical (read_from_buffer) 
    84           { 
    85             #pragma omp flush 
    86             copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    87           } 
    88         } 
    89         MPI_Barrier_local(comm); 
    90       } 
    91     } 
     39    MPI_Barrier_local(comm); 
    9240  } 
    93  
    94   int MPI_Scatter_local_float(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
    95   { 
    96     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    97     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    98  
    99     float *buffer = comm.my_buffer->buf_float; 
    100     float *send_buf = static_cast<float*>(const_cast<void*>(sendbuf)); 
    101     float *recv_buf = static_cast<float*>(recvbuf); 
    102  
    103     for(int k=0; k<num_ep; k++) 
    104     { 
    105       for(int j=0; j<count; j+=BUFFER_SIZE) 
    106       { 
    107         if(my_rank == 0) 
    108         { 
    109           #pragma omp critical (write_to_buffer) 
    110           { 
    111             copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
    112             #pragma omp flush 
    113           } 
    114         } 
    115  
    116         MPI_Barrier_local(comm); 
    117  
    118         if(my_rank == k) 
    119         { 
    120           #pragma omp critical (read_from_buffer) 
    121           { 
    122             #pragma omp flush 
    123             copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    124           } 
    125         } 
    126         MPI_Barrier_local(comm); 
    127       } 
    128     } 
    129   } 
    130  
    131   int MPI_Scatter_local_double(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
    132   { 
    133     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    134     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    135  
    136     double *buffer = comm.my_buffer->buf_double; 
    137     double *send_buf = static_cast<double*>(const_cast<void*>(sendbuf)); 
    138     double *recv_buf = static_cast<double*>(recvbuf); 
    139  
    140     for(int k=0; k<num_ep; k++) 
    141     { 
    142       for(int j=0; j<count; j+=BUFFER_SIZE) 
    143       { 
    144         if(my_rank == 0) 
    145         { 
    146           #pragma omp critical (write_to_buffer) 
    147           { 
    148             copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
    149             #pragma omp flush 
    150           } 
    151         } 
    152  
    153         MPI_Barrier_local(comm); 
    154  
    155         if(my_rank == k) 
    156         { 
    157           #pragma omp critical (read_from_buffer) 
    158           { 
    159             #pragma omp flush 
    160             copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    161           } 
    162         } 
    163         MPI_Barrier_local(comm); 
    164       } 
    165     } 
    166   } 
    167  
    168   int MPI_Scatter_local_long(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
    169   { 
    170     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    171     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    172  
    173     long *buffer = comm.my_buffer->buf_long; 
    174     long *send_buf = static_cast<long*>(const_cast<void*>(sendbuf)); 
    175     long *recv_buf = static_cast<long*>(recvbuf); 
    176  
    177     for(int k=0; k<num_ep; k++) 
    178     { 
    179       for(int j=0; j<count; j+=BUFFER_SIZE) 
    180       { 
    181         if(my_rank == 0) 
    182         { 
    183           #pragma omp critical (write_to_buffer) 
    184           { 
    185             copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
    186             #pragma omp flush 
    187           } 
    188         } 
    189  
    190         MPI_Barrier_local(comm); 
    191  
    192         if(my_rank == k) 
    193         { 
    194           #pragma omp critical (read_from_buffer) 
    195           { 
    196             #pragma omp flush 
    197             copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    198           } 
    199         } 
    200         MPI_Barrier_local(comm); 
    201       } 
    202     } 
    203   } 
    204  
    205  
    206   int MPI_Scatter_local_ulong(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
    207   { 
    208     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    209     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    210  
    211     unsigned long *buffer = comm.my_buffer->buf_ulong; 
    212     unsigned long *send_buf = static_cast<unsigned long*>(const_cast<void*>(sendbuf)); 
    213     unsigned long *recv_buf = static_cast<unsigned long*>(recvbuf); 
    214  
    215     for(int k=0; k<num_ep; k++) 
    216     { 
    217       for(int j=0; j<count; j+=BUFFER_SIZE) 
    218       { 
    219         if(my_rank == 0) 
    220         { 
    221           #pragma omp critical (write_to_buffer) 
    222           { 
    223             copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
    224             #pragma omp flush 
    225           } 
    226         } 
    227  
    228         MPI_Barrier_local(comm); 
    229  
    230         if(my_rank == k) 
    231         { 
    232           #pragma omp critical (read_from_buffer) 
    233           { 
    234             #pragma omp flush 
    235             copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    236           } 
    237         } 
    238         MPI_Barrier_local(comm); 
    239       } 
    240     } 
    241   } 
    242  
    243  
    244   int MPI_Scatter_local_char(const void *sendbuf, int count, void *recvbuf, MPI_Comm comm) 
    245   { 
    246     int my_rank = comm.ep_comm_ptr->size_rank_info[1].first; 
    247     int num_ep  = comm.ep_comm_ptr->size_rank_info[1].second; 
    248  
    249     char *buffer = comm.my_buffer->buf_char; 
    250     char *send_buf = static_cast<char*>(const_cast<void*>(sendbuf)); 
    251     char *recv_buf = static_cast<char*>(recvbuf); 
    252  
    253     for(int k=0; k<num_ep; k++) 
    254     { 
    255       for(int j=0; j<count; j+=BUFFER_SIZE) 
    256       { 
    257         if(my_rank == 0) 
    258         { 
    259           #pragma omp critical (write_to_buffer) 
    260           { 
    261             copy(send_buf+k*count+j, send_buf+k*count+j+min(BUFFER_SIZE, count-j), buffer); 
    262             #pragma omp flush 
    263           } 
    264         } 
    265  
    266         MPI_Barrier_local(comm); 
    267  
    268         if(my_rank == k) 
    269         { 
    270           #pragma omp critical (read_from_buffer) 
    271           { 
    272             #pragma omp flush 
    273             copy(buffer, buffer+min(BUFFER_SIZE, count-j), recv_buf+j); 
    274           } 
    275         } 
    276         MPI_Barrier_local(comm); 
    277       } 
    278     } 
    279   } 
    280  
    281  
    282  
    283  
    28441 
    28542  int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) 
     
    28744    if(!comm.is_ep) 
    28845    { 
    289       ::MPI_Scatter(sendbuf, sendcount, static_cast< ::MPI_Datatype>(sendtype), recvbuf, recvcount, static_cast< ::MPI_Datatype>(recvtype), 
    290                     root, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    291       return 0; 
     46      return ::MPI_Scatter(sendbuf, sendcount, to_mpi_type(sendtype), recvbuf, recvcount, to_mpi_type(recvtype), root, to_mpi_comm(comm.mpi_comm)); 
    29247    } 
     48    
     49    assert(sendcount == recvcount); 
    29350 
    294     if(!comm.mpi_comm) return 0; 
    295  
    296     assert(static_cast< ::MPI_Datatype>(sendtype) == static_cast< ::MPI_Datatype>(recvtype) && sendcount == recvcount); 
     51    int ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
     52    int ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
     53    int mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
     54    int ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
     55    int num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
     56    int mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    29757 
    29858    int root_mpi_rank = comm.rank_map->at(root).second; 
    29959    int root_ep_loc = comm.rank_map->at(root).first; 
    30060 
    301     int ep_rank, ep_rank_loc, mpi_rank; 
    302     int ep_size, num_ep, mpi_size; 
    303  
    304     ep_rank = comm.ep_comm_ptr->size_rank_info[0].first; 
    305     ep_rank_loc = comm.ep_comm_ptr->size_rank_info[1].first; 
    306     mpi_rank = comm.ep_comm_ptr->size_rank_info[2].first; 
    307     ep_size = comm.ep_comm_ptr->size_rank_info[0].second; 
    308     num_ep = comm.ep_comm_ptr->size_rank_info[1].second; 
    309     mpi_size = comm.ep_comm_ptr->size_rank_info[2].second; 
    310  
     61    bool is_master = (ep_rank_loc==0 && mpi_rank != root_mpi_rank ) || ep_rank == root; 
     62    bool is_root = ep_rank == root; 
    31163 
    31264    MPI_Datatype datatype = sendtype; 
     
    31466 
    31567    ::MPI_Aint datasize, lb; 
     68    ::MPI_Type_get_extent(to_mpi_type(datatype), &lb, &datasize); 
     69     
     70    void *tmp_sendbuf; 
     71    if(is_root) tmp_sendbuf = new void*[ep_size * count * datasize]; 
    31672 
    317     ::MPI_Type_get_extent(static_cast< ::MPI_Datatype>(datatype), &lb, &datasize); 
     73    // reorder tmp_sendbuf 
     74    std::vector<int>local_ranks(num_ep); 
     75    std::vector<int>ranks(ep_size); 
     76 
     77    if(mpi_rank == root_mpi_rank) MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), root_ep_loc, comm); 
     78    else                          MPI_Gather_local(&ep_rank, 1, MPI_INT, local_ranks.data(), 0, comm); 
    31879 
    31980 
    320     void *master_sendbuf; 
    321     void *local_recvbuf; 
     81    std::vector<int> recvcounts(mpi_size, 0); 
     82    std::vector<int> displs(mpi_size, 0); 
    32283 
    323     if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) 
     84 
     85    if(is_master) 
    32486    { 
    325       if(ep_rank_loc == 0) master_sendbuf = new void*[datasize*count*ep_size]; 
     87      for(int i=0; i<ep_size; i++) 
     88      { 
     89        recvcounts[comm.rank_map->at(i).second]++; 
     90      } 
    32691 
    327       innode_memcpy(root_ep_loc, sendbuf, 0, master_sendbuf, count*ep_size, datatype, comm); 
     92      for(int i=1; i<mpi_size; i++) 
     93        displs[i] = displs[i-1] + recvcounts[i-1]; 
     94 
     95      ::MPI_Gatherv(local_ranks.data(), num_ep, MPI_INT, ranks.data(), recvcounts.data(), displs.data(), MPI_INT, root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
    32896    } 
    32997 
    33098 
    33199 
    332     if(ep_rank_loc == 0) 
     100    // if(is_root) printf("\nranks = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", ranks[0], ranks[1], ranks[2], ranks[3], ranks[4], ranks[5], ranks[6], ranks[7],  
     101    //                                                                                   ranks[8], ranks[9], ranks[10], ranks[11], ranks[12], ranks[13], ranks[14], ranks[15]); 
     102 
     103    if(is_root) 
     104    for(int i=0; i<ep_size; i++) 
    333105    { 
    334       int mpi_sendcnt = count*num_ep; 
    335       int mpi_scatterv_sendcnt[mpi_size]; 
    336       int displs[mpi_size]; 
     106      memcpy(tmp_sendbuf + i*datasize*count, sendbuf + ranks[i]*datasize*count, count*datasize); 
     107    } 
    337108 
    338       local_recvbuf = new void*[datasize*mpi_sendcnt]; 
    339  
    340       ::MPI_Gather(&mpi_sendcnt, 1, MPI_INT, mpi_scatterv_sendcnt, 1, MPI_INT, root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    341  
    342       displs[0] = 0; 
    343       for(int i=1; i<mpi_size; i++) 
    344         displs[i] = displs[i-1] + mpi_scatterv_sendcnt[i-1]; 
     109    // if(is_root) printf("\ntmp_sendbuf = %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d\n", static_cast<int*>(tmp_sendbuf)[0], static_cast<int*>(tmp_sendbuf)[2], static_cast<int*>(tmp_sendbuf)[4], static_cast<int*>(tmp_sendbuf)[6],  
     110    //                                                                           static_cast<int*>(tmp_sendbuf)[8], static_cast<int*>(tmp_sendbuf)[10], static_cast<int*>(tmp_sendbuf)[12], static_cast<int*>(tmp_sendbuf)[14],  
     111    //                                                                           static_cast<int*>(tmp_sendbuf)[16], static_cast<int*>(tmp_sendbuf)[18], static_cast<int*>(tmp_sendbuf)[20], static_cast<int*>(tmp_sendbuf)[22],  
     112    //                                                                           static_cast<int*>(tmp_sendbuf)[24], static_cast<int*>(tmp_sendbuf)[26], static_cast<int*>(tmp_sendbuf)[28], static_cast<int*>(tmp_sendbuf)[30] ); 
    345113 
    346114 
    347       if(root_ep_loc!=0) 
    348       { 
    349         ::MPI_Scatterv(master_sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 
    350                      local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    351       } 
    352       else 
    353       { 
    354         ::MPI_Scatterv(sendbuf, mpi_scatterv_sendcnt, displs, static_cast< ::MPI_Datatype>(datatype), 
    355                      local_recvbuf, mpi_sendcnt, static_cast< ::MPI_Datatype>(datatype), root_mpi_rank, static_cast< ::MPI_Comm>(comm.mpi_comm)); 
    356       } 
     115    // MPI_Scatterv from root to masters 
     116    void* local_recvbuf; 
     117    if(is_master) local_recvbuf = new void*[datasize * num_ep * count]; 
     118 
     119 
     120    if(is_master) 
     121    { 
     122      int local_sendcount = num_ep * count; 
     123      ::MPI_Gather(&local_sendcount, 1, to_mpi_type(MPI_INT), recvcounts.data(), 1, to_mpi_type(MPI_INT), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
     124       
     125      if(is_root) for(int i=1; i<mpi_size; i++) displs[i] = displs[i-1] + recvcounts[i-1]; 
     126 
     127      ::MPI_Scatterv(tmp_sendbuf, recvcounts.data(), displs.data(), to_mpi_type(sendtype), local_recvbuf, num_ep*count, to_mpi_type(recvtype), root_mpi_rank, to_mpi_comm(comm.mpi_comm)); 
     128 
     129      // printf("local_recvbuf = %d %d %d %d\n", static_cast<int*>(local_recvbuf)[0], static_cast<int*>(local_recvbuf)[1], static_cast<int*>(local_recvbuf)[2], static_cast<int*>(local_recvbuf)[3]); 
     130                                                          // static_cast<int*>(local_recvbuf)[4], static_cast<int*>(local_recvbuf)[5], static_cast<int*>(local_recvbuf)[6], static_cast<int*>(local_recvbuf)[7]); 
    357131    } 
    358132 
    359     MPI_Scatter_local2(local_recvbuf, count, datatype, recvbuf, comm); 
     133    if(mpi_rank == root_mpi_rank) MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, root_ep_loc, comm); 
     134    else                          MPI_Scatter_local(local_recvbuf, count, sendtype, recvbuf, recvcount, recvtype, 0, comm); 
    360135 
    361     if(ep_rank_loc == 0) 
    362     { 
    363       if(datatype == MPI_INT) 
    364       { 
    365         if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<int*>(master_sendbuf); 
    366         delete[] static_cast<int*>(local_recvbuf); 
    367       } 
    368       else if(datatype == MPI_FLOAT) 
    369       { 
    370         if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<float*>(master_sendbuf); 
    371         delete[] static_cast<float*>(local_recvbuf); 
    372       } 
    373       else if(datatype == MPI_DOUBLE) 
    374       { 
    375         if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<double*>(master_sendbuf); 
    376         delete[] static_cast<double*>(local_recvbuf); 
    377       } 
    378       else if(datatype == MPI_LONG) 
    379       { 
    380         if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<long*>(master_sendbuf); 
    381         delete[] static_cast<long*>(local_recvbuf); 
    382       } 
    383       else if(datatype == MPI_UNSIGNED_LONG) 
    384       { 
    385         if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<unsigned long*>(master_sendbuf); 
    386         delete[] static_cast<unsigned long*>(local_recvbuf); 
    387       } 
    388       else //if(datatype == MPI_DOUBLE) 
    389       { 
    390         if(root_ep_loc!=0 && mpi_rank == root_mpi_rank) delete[] static_cast<char*>(master_sendbuf); 
    391         delete[] static_cast<char*>(local_recvbuf); 
    392       } 
    393     } 
     136    if(is_root)   delete[] tmp_sendbuf; 
     137    if(is_master) delete[] local_recvbuf; 
     138  } 
    394139 
    395   } 
    396140} 
Note: See TracChangeset for help on using the changeset viewer.