source: codes/icosagcm/trunk/src/parallel/transfert_mpi.f90

Last change on this file was 1053, checked in by dubos, 4 years ago

trunk : simplify allocate_field -- tested on 4 GPUs (TBC)

! Module for MPI communication of field halos
! This module uses Fortran 2003 features : move_alloc intrinsic, pointer bounds remapping, allocatable type fields
module transfert_mpi_mod
  use abort_mod, only : dynamico_abort, abort_acc
  use profiling_mod, only : enter_profile, exit_profile, register_id
  use domain_mod, only : ndomain, ndomain_glo, domain, domain_glo, domloc_glo_ind, domglo_rank, domglo_loc_ind
  use field_mod, only : t_field, field_T, field_U, field_Z
  use transfert_request_mod
  implicit none
  private

  ! Describes how to pack/unpack a message from a local domain to another
  type t_local_submessage
    integer :: src_ind_loc, dest_ind_loc ! local indices of source and destination domains
    integer :: npoints ! Number of cells to transfer (dim12)
    integer, allocatable :: displ_src(:) ! List of indices to copy from domain src_ind_loc
    integer, allocatable :: displ_dest(:) ! List of indices to copy to domain dest_ind_loc
    integer, allocatable :: sign(:) ! Sign change to be applied for vector requests
  end type

  ! Describes how to pack/unpack a message from a domain to another, and contains MPI buffer
  type t_submessage
    integer :: ind_loc, remote_ind_glo, remote_rank ! local domain index, global index and MPI rank of the remote domain
    integer :: npoints ! Number of cells to transfer (dim12)
    integer, allocatable :: displs(:) ! List of indices to copy from field to buffer for each level
    integer, allocatable :: sign(:) ! Sign change to be applied for vector requests
    integer :: mpi_buffer_displ = -1
  end type

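  ! Flattened (struct-of-arrays) copy of message_in/message_out, one entry per halo point,
  ! used by the pack/unpack loops in copy_HtoB/copy_BtoH. For point i and level (d3,d4),
  ! the MPI buffer location is mpi_displ(i) + level_offset(i)*((d3-1) + dim3*(d4-1)).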
  type t_compact_submessages
    integer :: npoints
    integer, allocatable, dimension(:) :: field_ind, field_displ, sign, remote_rank, mpi_displ, level_offset
  end type

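  ! Flattened (struct-of-arrays) copy of message_local : direct field-to-field halo copies
  ! between domains owned by the same MPI rank, used by copy_HtoH.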
35  type t_compact_local_submessages
36    integer :: npoints
37    integer, allocatable, dimension(:) :: field_ind_src, field_displ_src, sign, field_ind_dest, field_displ_dest
38  end type
39 
[999]40  type mpi_buffer_t
41    integer :: n
42    real, allocatable :: buff(:)
43  end type
[266]44
[963]45  ! Describes how to exchange data for a field.
46  type t_message
47    type (t_field), pointer :: field(:) => null() ! Field to exchange
48    type (t_request), pointer :: request(:) => null() ! Type of message to send
49    type (t_local_submessage), pointer :: message_local(:) ! Local halo copies
50    type (t_submessage), pointer :: message_in(:) ! Messages to recieve from remote ranks and to copy back to the field
51    type (t_submessage), pointer :: message_out(:) ! Halos to copy to MPI buffer and to send to remote ranks
[999]52    type (mpi_buffer_t), pointer :: mpi_buffer_in(:)
53    type (mpi_buffer_t), pointer :: mpi_buffer_out(:)
[1003]54    type (t_compact_submessages), pointer :: message_in_compact
55    type (t_compact_submessages), pointer :: message_out_compact
56    type (t_compact_local_submessages), pointer  :: message_local_compact   
[963]57    integer, pointer :: mpi_requests_in(:) ! MPI requests used for message_in.
58    integer, pointer :: mpi_requests_out(:) ! MPI requests used for message_out.
59    ! NOTE : requests are persistant requests initialized in init_message. MPI_Start and MPI_Wait are then used to initiate and complete communications.
60    ! ex : Give mpi_requests_in(i) to MPI_Start to send the buffer contained in message_in(i)
61    integer :: send_seq ! Sequence number : send_seq is incremented each time send_message is called
62    integer :: wait_seq ! Sequence number : wait_seq is incremented each time wait_message is called
63    logical :: ondevice ! Ready to transfer ondevice field
64  end type t_message
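  ! Illustrative call sequence for one halo exchange (sketch only ; 'field' stands for any
  ! t_field array created by allocate_field, req_i1 is one of the requests exported below) :
  !   type(t_message), save :: msg
  !   call init_message(field, req_i1, msg)  ! build submessages, allocate buffers, create persistent requests
  !   call send_message(field, msg)          ! pack halos, start receives/sends, do local copies
  !   call wait_message(msg)                 ! wait for MPI completion and unpack into the field
  !   call finalize_message(msg)             ! free persistent requests and buffers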

  public :: t_message, t_request, &
    req_i1, req_e1_scal, req_e1_vect, &
    req_i0, req_e0_scal, req_e0_vect, &
    req_z1_scal, &
    init_transfert, &
    init_message, &
    finalize_message, &
    send_message, &
    wait_message, &
    test_message

  ! ---- Private variables ----
  ! Profiling ids for MPI
  integer :: profile_mpi, profile_mpi_copies, profile_mpi_waitall, profile_mpi_barrier
contains
  ! Initialize transfers : must be called before any other transfert_mpi routine
  subroutine init_transfert
    use mpi_mod, only : MPI_THREAD_SINGLE, MPI_THREAD_FUNNELED
    use mpipara, only : mpi_threading_mode
    use profiling_mod, only : register_id
    logical, parameter :: profile_mpi_detail = .true.

    !$omp master
    ! Check requested thread support level
    if( mpi_threading_mode /= MPI_THREAD_SINGLE .and. mpi_threading_mode /= MPI_THREAD_FUNNELED ) call dynamico_abort("Only single and funneled threading modes are supported.")

    ! Register profiling ids
    call register_id("MPI", profile_mpi)
    if( profile_mpi_detail ) then
      call register_id("MPI_copies", profile_mpi_copies)
      call register_id("MPI_waitall", profile_mpi_waitall)
      call register_id("MPI_omp_barrier", profile_mpi_barrier)
    else
      profile_mpi_copies = profile_mpi
      profile_mpi_waitall = profile_mpi
      profile_mpi_barrier = profile_mpi
    endif

    ! Initialize requests
    call init_all_requests()
    !$omp end master
    !$omp barrier
  end subroutine

  subroutine init_message(field, request, message, name)
    use mpi_mod
    use mpipara
    type(t_field), pointer, intent(in) :: field(:)
    type(t_request),pointer, intent(in) :: request(:)
    type(t_message), target, intent(out) :: message ! Needs intent out for call to finalize_message
    character(len=*), intent(in),optional :: name
    integer, parameter :: INITIAL_ALLOC_SIZE = 10, GROW_FACTOR = 2

    type(t_submessage) :: submessage_in, submessage_out
    type(t_local_submessage) :: submessage_local
    integer :: dim3, dim4, npoints, last_point
    integer :: ind, ind_loc, remote_ind_glo, loc_ind_glo, i, k, remote_rank
    integer :: message_in_size, message_out_size, message_local_size, buffer_in_size, buffer_out_size
    type(t_local_submessage), allocatable :: message_local_tmp(:)
    type(t_submessage), allocatable :: message_in_tmp(:), message_out_tmp(:)
    integer :: field_type

    !$omp barrier
    !$omp master
    ! init off-device
    message%ondevice=.false.
    message%send_seq = 0
    message%wait_seq = 0

    if( request(1)%field_type /= field(1)%field_type ) call dynamico_abort( "init_message : field_type/request mismatch" )
    field_type = request(1)%field_type

    dim3 = size(field(1)%rval4d,2)
    dim4 = size(field(1)%rval4d,3)
    message%field => field
    message%request => request
    ! Create lists of inbound/outbound/local messages
    allocate(message_in_tmp(INITIAL_ALLOC_SIZE))
    message_in_size=0
    allocate(message_out_tmp(INITIAL_ALLOC_SIZE))
    message_out_size=0
    allocate(message_local_tmp(INITIAL_ALLOC_SIZE))
    message_local_size=0
    do loc_ind_glo = 1, ndomain_glo
      do remote_ind_glo = 1, ndomain_glo
        if(domglo_rank(loc_ind_glo) == mpi_rank) then
          ind_loc = domglo_loc_ind(loc_ind_glo)
          if( domglo_rank(remote_ind_glo) == mpi_rank ) then ! If sending to a local domain
            if(request(ind_loc)%points_HtoB(remote_ind_glo)%npoints > 0 ) then ! Add only non-empty messages
              ! Add local message ind_loc -> remote_ind_glo, aggregating submessage_in and submessage_out into submessage_local
              submessage_out = make_submessage( field_type, request(ind_loc)%points_HtoB(remote_ind_glo), &
                                                ind_loc, remote_ind_glo, dim3, dim4, request(1)%vector )
              submessage_in = make_submessage( field_type, request(domglo_loc_ind(remote_ind_glo))%points_BtoH(domloc_glo_ind(ind_loc)), &
                                              domglo_loc_ind(remote_ind_glo), domloc_glo_ind(ind_loc), dim3, dim4, request(1)%vector)
              submessage_local%src_ind_loc = ind_loc
              submessage_local%dest_ind_loc = domglo_loc_ind(remote_ind_glo)
              submessage_local%npoints = submessage_out%npoints
              submessage_local%displ_src = submessage_out%displs
              submessage_local%displ_dest = submessage_in%displs
              submessage_local%sign = submessage_in%sign
              ! Add to local message list
              call array_append_local_submessage( message_local_tmp, message_local_size, submessage_local)
            endif
          else ! If remote domain
            ! When there is data to send to a remote domain, add a submessage to message%message_out
            if( request(ind_loc)%points_HtoB(remote_ind_glo)%npoints > 0 ) then
              submessage_out = make_submessage( field_type, request(ind_loc)%points_HtoB(remote_ind_glo), &
                                                ind_loc, remote_ind_glo, dim3, dim4, request(1)%vector )
              call array_append_submessage( message_out_tmp, message_out_size, submessage_out )
            end if
          end if
        end if
      end do
    end do
    ! Recv and send submessages are transposed so that messages are received and sent in the same order
    ! We iterate over the global domain index to match sends with receives (local domains are not ordered like global domains)
    do remote_ind_glo = 1, ndomain_glo
      do loc_ind_glo = 1, ndomain_glo
        if( (domglo_rank(loc_ind_glo) == mpi_rank) .and. (domglo_rank(remote_ind_glo) /= mpi_rank) ) then
          ind_loc = domglo_loc_ind(loc_ind_glo)
          if( request(ind_loc)%points_BtoH(remote_ind_glo)%npoints > 0 ) then
            submessage_in = make_submessage( field_type, request(ind_loc)%points_BtoH(remote_ind_glo), &
                                             ind_loc, remote_ind_glo, dim3, dim4, request(1)%vector )
            call array_append_submessage( message_in_tmp, message_in_size, submessage_in )
          end if
        end if
      end do
    end do

    ! Trim message_xx_tmp and put it in message%message_xx
    allocate(message%message_in(message_in_size)); message%message_in(:) = message_in_tmp(:message_in_size)
    allocate(message%message_out(message_out_size)); message%message_out(:) = message_out_tmp(:message_out_size)
    allocate(message%message_local(message_local_size)); message%message_local(:) = message_local_tmp(:message_local_size)

    ! Allocate MPI buffers
    allocate( message%mpi_buffer_in(0:mpi_size-1) )
    allocate( message%mpi_buffer_out(0:mpi_size-1) )
    do i = 0, mpi_size-1
      buffer_in_size = dim3*dim4*sum( message%message_in(:)%npoints, message%message_in(:)%remote_rank == i )
      buffer_out_size = dim3*dim4*sum( message%message_out(:)%npoints, message%message_out(:)%remote_rank == i )
      !TODO : what if size == 0 ?
      allocate( message%mpi_buffer_in(i)%buff( buffer_in_size ) )
      allocate( message%mpi_buffer_out(i)%buff( buffer_out_size ) )
      message%mpi_buffer_in(i)%n=0
      message%mpi_buffer_out(i)%n=0
    end do
    ! Set offsets in submessages
    do i=1, size(message%message_out)
      remote_rank = message%message_out(i)%remote_rank
      message%message_out(i)%mpi_buffer_displ = message%mpi_buffer_out(remote_rank)%n
      message%mpi_buffer_out(remote_rank)%n = message%mpi_buffer_out(remote_rank)%n + message%message_out(i)%npoints*dim3*dim4
    end do
    do i=1, size(message%message_in)
      remote_rank = message%message_in(i)%remote_rank
      message%message_in(i)%mpi_buffer_displ = message%mpi_buffer_in(remote_rank)%n
      message%mpi_buffer_in(remote_rank)%n = message%mpi_buffer_in(remote_rank)%n + message%message_in(i)%npoints*dim3*dim4
    end do
    ! Create persistent MPI requests
    allocate( message%mpi_requests_in(0:mpi_size-1) )
    allocate( message%mpi_requests_out(0:mpi_size-1) )
    message%mpi_requests_in(0:mpi_size-1) = MPI_REQUEST_NULL
    message%mpi_requests_out(0:mpi_size-1) = MPI_REQUEST_NULL
    do i = 0, mpi_size-1
      if(  size(message%mpi_buffer_in(i)%buff) /= message%mpi_buffer_in(i)%n &
      .or. size(message%mpi_buffer_out(i)%buff) /= message%mpi_buffer_out(i)%n)&
        call dynamico_abort("Internal error in transfert_mpi : mpi buffer size different from expected")
      if( message%mpi_buffer_out(i)%n > 0) then
        call MPI_Send_Init( message%mpi_buffer_out(i)%buff, message%mpi_buffer_out(i)%n, MPI_REAL8, i,&
                            100, comm_icosa, message%mpi_requests_out(i), ierr )
      endif
      if( message%mpi_buffer_in(i)%n > 0) then
        call MPI_Recv_Init( message%mpi_buffer_in(i)%buff, message%mpi_buffer_in(i)%n, MPI_REAL8, i,&
                            100, comm_icosa, message%mpi_requests_in(i), ierr )
      endif
    end do
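    ! Note : these persistent requests point at host buffers ; if the field later lives on the GPU,
    ! message_create_ondevice frees them and re-creates them on the device buffer addresses.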

    allocate(message%message_in_compact)
    message%message_in_compact%npoints = sum(message%message_in(:)%npoints)
    npoints = message%message_in_compact%npoints
    allocate(message%message_in_compact%field_ind(npoints))
    allocate(message%message_in_compact%field_displ(npoints))
    allocate(message%message_in_compact%sign(npoints))
    allocate(message%message_in_compact%remote_rank(npoints))
    allocate(message%message_in_compact%mpi_displ(npoints))
    allocate(message%message_in_compact%level_offset(npoints))

    last_point=0
    do i = 1, size( message%message_in )
      do k = 1, message%message_in(i)%npoints
        last_point = last_point+1
        message%message_in_compact%field_ind(last_point)   = message%message_in(i)%ind_loc
        message%message_in_compact%field_displ(last_point) = message%message_in(i)%displs(k)
        message%message_in_compact%sign(last_point)        = message%message_in(i)%sign(k)
        message%message_in_compact%remote_rank(last_point) = message%message_in(i)%remote_rank
        message%message_in_compact%mpi_displ(last_point)   = message%message_in(i)%mpi_buffer_displ + k
        message%message_in_compact%level_offset(last_point)= message%message_in(i)%npoints
      end do
    end do

    allocate(message%message_out_compact)
    message%message_out_compact%npoints = sum(message%message_out(:)%npoints)
    npoints = message%message_out_compact%npoints
    allocate(message%message_out_compact%field_ind(npoints))
    allocate(message%message_out_compact%field_displ(npoints))
    allocate(message%message_out_compact%sign(npoints))
    allocate(message%message_out_compact%remote_rank(npoints))
    allocate(message%message_out_compact%mpi_displ(npoints))
    allocate(message%message_out_compact%level_offset(npoints))

    last_point=0
    do i = 1, size( message%message_out )
      do k = 1, message%message_out(i)%npoints
        last_point = last_point+1
        message%message_out_compact%field_ind(last_point)   = message%message_out(i)%ind_loc
        message%message_out_compact%field_displ(last_point) = message%message_out(i)%displs(k)
        message%message_out_compact%sign(last_point)        = message%message_out(i)%sign(k)
        message%message_out_compact%remote_rank(last_point) = message%message_out(i)%remote_rank
        message%message_out_compact%mpi_displ(last_point)   = message%message_out(i)%mpi_buffer_displ + k
        message%message_out_compact%level_offset(last_point)= message%message_out(i)%npoints
      end do
    end do

    allocate(message%message_local_compact)
    message%message_local_compact%npoints = sum(message%message_local(:)%npoints)
    npoints = message%message_local_compact%npoints
    allocate(message%message_local_compact%field_ind_src(npoints))
    allocate(message%message_local_compact%field_displ_src(npoints))
    allocate(message%message_local_compact%sign(npoints))
    allocate(message%message_local_compact%field_ind_dest(npoints))
    allocate(message%message_local_compact%field_displ_dest(npoints))

    last_point=0
    do i = 1, size( message%message_local )
      do k = 1, message%message_local(i)%npoints
        last_point = last_point+1
        message%message_local_compact%field_ind_src(last_point)   = message%message_local(i)%src_ind_loc
        message%message_local_compact%field_displ_src(last_point) = message%message_local(i)%displ_src(k)
        message%message_local_compact%sign(last_point)            = message%message_local(i)%sign(k)
        message%message_local_compact%field_ind_dest(last_point)  = message%message_local(i)%dest_ind_loc
        message%message_local_compact%field_displ_dest(last_point)= message%message_local(i)%displ_dest(k)
      end do
    end do

    !$omp end master
    !$omp barrier
  contains
    ! Generate submessage from points
    function make_submessage(field_type, points, ind_loc, remote_ind_glo, dim3, dim4, vector) result(submessage)
      use dimensions, only : swap_dimensions, iim, u_pos, z_pos
      integer, intent(in) :: field_type
      type(t_points), intent(in) :: points
      integer, intent(in) :: ind_loc, remote_ind_glo, dim3, dim4
      logical, intent(in) :: vector
      integer :: k
      type(t_submessage) :: submessage

      call swap_dimensions(ind_loc)
      submessage%ind_loc = ind_loc
      submessage%remote_ind_glo = remote_ind_glo
      submessage%remote_rank = domglo_rank(remote_ind_glo)
      submessage%npoints = points%npoints
      submessage%mpi_buffer_displ = -1 ! Buffers not allocated yet
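      ! displs : linear offsets of the halo points into rval4d(:,d3,d4) : i + (j-1)*iim,
      ! shifted by u_pos/z_pos to address the edge/vertex sub-arrays of U/Z fields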
      allocate( submessage%displs( points%npoints ) )
      submessage%displs(:) = points%i + (points%j-1)*iim
      if(field_type == field_U) submessage%displs = submessage%displs + u_pos( points%elt )
      if(field_type == field_Z) submessage%displs = submessage%displs + z_pos( points%elt )
      allocate(submessage%sign( points%npoints ))
      if( vector ) then ! For U fields only
        submessage%sign(:) = (/( domain(ind_loc)%edge_assign_sign(points%elt(k)-1, points%i(k), points%j(k)) ,k=1,points%npoints)/)
      else
        submessage%sign(:) = 1
      endif
    end function

    ! Add element to array, reallocating if necessary
    subroutine array_append_submessage( a, a_size, elt )
      type(t_submessage), allocatable, intent(inout) :: a(:)
      integer, intent(inout) :: a_size
      type(t_submessage), intent(in) :: elt
      type(t_submessage), allocatable :: a_tmp(:)
      integer, parameter :: GROW_FACTOR = 2

      if( size( a ) <= a_size ) then
        allocate( a_tmp ( a_size * GROW_FACTOR ) )
        a_tmp(1:a_size) = a(1:a_size)
        call move_alloc(a_tmp, a)
      end if
      a_size = a_size + 1
      a(a_size) = elt
    end subroutine
    ! Add element to array, reallocating if necessary
    subroutine array_append_local_submessage( a, a_size, elt )
      type(t_local_submessage), allocatable, intent(inout) :: a(:)
      integer, intent(inout) :: a_size
      type(t_local_submessage), intent(in) :: elt
      type(t_local_submessage), allocatable :: a_tmp(:)
      integer, parameter :: GROW_FACTOR = 2

      if( size( a ) <= a_size ) then
        allocate( a_tmp ( a_size * GROW_FACTOR ) )
        a_tmp(1:a_size) = a(1:a_size)
        call move_alloc(a_tmp, a)
      end if
      a_size = a_size + 1
      a(a_size) = elt
    end subroutine
    ! I beg forgiveness from the god of copy-paste, for I have sinned
  end subroutine

  subroutine message_create_ondevice(message)
    use mpi_mod
    use mpipara, only : mpi_size, comm_icosa
    type(t_message), intent(inout) :: message
    integer :: i, ierr

    if( message%ondevice ) call dynamico_abort("Message already on device")

    !$acc enter data copyin(message) async
    !$acc enter data copyin(message%mpi_buffer_in(:)) async
    !$acc enter data copyin(message%mpi_buffer_out(:)) async
    do i = 0, mpi_size-1
      !$acc enter data copyin(message%mpi_buffer_in(i)%buff(:)) async
      !$acc enter data copyin(message%mpi_buffer_out(i)%buff(:)) async
    end do
    !!$acc enter data copyin(message%message_in(:)) async
    !do i = 1, size( message%message_in )
    !  !$acc enter data copyin(message%message_in(i)%displs(:)) async
    !  !$acc enter data copyin(message%message_in(i)%sign(:)) async
    !end do
    !!$acc enter data copyin(message%message_out(:)) async
    !do i = 1, size( message%message_out )
    !  !$acc enter data copyin(message%message_out(i)%displs(:)) async
    !  !!$acc enter data copyin(message%message_out(i)%sign(:)) async
    !end do
    !!$acc enter data copyin(message%message_local(:)) async
    !do i = 1, size( message%message_local )
    !  !$acc enter data copyin(message%message_local(i)%displ_src(:)) async
    !  !$acc enter data copyin(message%message_local(i)%displ_dest(:)) async
    !  !$acc enter data copyin(message%message_local(i)%sign(:)) async
    !end do
    !$acc enter data copyin(message%field(:)) async
    do i = 1, ndomain
      !$acc enter data copyin(message%field(i)%rval4d(:,:,:)) async
    end do

    !$acc enter data copyin(message%message_in_compact) async
    !$acc enter data copyin(message%message_in_compact%field_ind(:)) async
    !$acc enter data copyin(message%message_in_compact%field_displ(:)) async
    !$acc enter data copyin(message%message_in_compact%sign(:)) async
    !$acc enter data copyin(message%message_in_compact%remote_rank(:)) async
    !$acc enter data copyin(message%message_in_compact%mpi_displ(:)) async
    !$acc enter data copyin(message%message_in_compact%level_offset(:)) async

    !$acc enter data copyin(message%message_out_compact) async
    !$acc enter data copyin(message%message_out_compact%field_ind(:)) async
    !$acc enter data copyin(message%message_out_compact%field_displ(:)) async
    !$acc enter data copyin(message%message_out_compact%sign(:)) async
    !$acc enter data copyin(message%message_out_compact%remote_rank(:)) async
    !$acc enter data copyin(message%message_out_compact%mpi_displ(:)) async
    !$acc enter data copyin(message%message_out_compact%level_offset(:)) async

    !$acc enter data copyin(message%message_local_compact) async
    !$acc enter data copyin(message%message_local_compact%field_ind_src(:)) async
    !$acc enter data copyin(message%message_local_compact%field_displ_src(:)) async
    !$acc enter data copyin(message%message_local_compact%sign(:)) async
    !$acc enter data copyin(message%message_local_compact%field_ind_dest(:)) async
    !$acc enter data copyin(message%message_local_compact%field_displ_dest(:)) async

    !$acc wait
    do i = 0, mpi_size-1
      if( message%mpi_requests_out(i) /= MPI_REQUEST_NULL ) then
        call MPI_Request_free(message%mpi_requests_out(i), ierr)
        !$acc host_data use_device(message%mpi_buffer_out(i)%buff)
          ! /!\ buff(1) is important for PGI to avoid a temporary array copy
          call MPI_Send_Init( message%mpi_buffer_out(i)%buff(1), message%mpi_buffer_out(i)%n, MPI_REAL8, i,&
                              0, comm_icosa, message%mpi_requests_out(i), ierr )
        !$acc end host_data
      end if
      if( message%mpi_requests_in(i) /= MPI_REQUEST_NULL ) then
        call MPI_Request_free(message%mpi_requests_in(i), ierr)
        !$acc host_data use_device(message%mpi_buffer_in(i)%buff)
          call MPI_Recv_Init( message%mpi_buffer_in(i)%buff(1), message%mpi_buffer_in(i)%n, MPI_REAL8, i,&
                              0, comm_icosa, message%mpi_requests_in(i), ierr )
        !$acc end host_data
      endif
    end do
    message%ondevice=.true.
    !!$acc update device(message%ondevice)
  end subroutine

  subroutine message_delete_ondevice(message)
    use mpipara, only : mpi_size
    type(t_message), intent(inout) :: message
    integer :: i

    if( .not. message%ondevice ) call dynamico_abort("Message not on device")

    !do i = 1, size( message%message_in )
    !  !$acc exit data delete(message%message_in(i)%displs(:)) async
    !  !$acc exit data delete(message%message_in(i)%sign(:)) async
    !end do
    !!$acc exit data delete(message%message_in(:)) async
    !do i = 1, size( message%message_out )
    !  !$acc exit data delete(message%message_out(i)%displs(:)) async
    !  !!$acc exit data delete(message%message_out(i)%sign(:)) async
    !end do
    !!$acc exit data delete(message%message_out(:)) async
    !do i = 1, size( message%message_local )
    !  !$acc exit data delete(message%message_local(i)%displ_src(:)) async
    !  !$acc exit data delete(message%message_local(i)%displ_dest(:)) async
    !  !$acc exit data delete(message%message_local(i)%sign(:)) async
    !end do
    !!$acc exit data delete(message%message_local(:)) async
    do i = 0, mpi_size-1
      !$acc exit data delete(message%mpi_buffer_in(i)%buff(:)) async
      !$acc exit data delete(message%mpi_buffer_out(i)%buff(:)) async
    end do
    !$acc exit data delete(message%mpi_buffer_in(:)) async
    !$acc exit data delete(message%mpi_buffer_out(:)) async
    do i = 1, ndomain
      !$acc exit data delete(message%field(i)%rval4d(:,:,:)) async
    end do
    !$acc exit data delete(message%field(:)) async
    !$acc exit data delete(message) async

    !$acc exit data delete(message%message_in_compact%field_ind(:)) async
    !$acc exit data delete(message%message_in_compact%field_displ(:)) async
    !$acc exit data delete(message%message_in_compact%sign(:)) async
    !$acc exit data delete(message%message_in_compact%remote_rank(:)) async
    !$acc exit data delete(message%message_in_compact%mpi_displ(:)) async
    !$acc exit data delete(message%message_in_compact%level_offset(:)) async
    !$acc exit data delete(message%message_in_compact) async

    !$acc exit data delete(message%message_out_compact%field_ind(:)) async
    !$acc exit data delete(message%message_out_compact%field_displ(:)) async
    !$acc exit data delete(message%message_out_compact%sign(:)) async
    !$acc exit data delete(message%message_out_compact%remote_rank(:)) async
    !$acc exit data delete(message%message_out_compact%mpi_displ(:)) async
    !$acc exit data delete(message%message_out_compact%level_offset(:)) async
    !$acc exit data delete(message%message_out_compact) async

    !$acc exit data delete(message%message_local_compact%field_ind_src(:)) async
    !$acc exit data delete(message%message_local_compact%field_displ_src(:)) async
    !$acc exit data delete(message%message_local_compact%sign(:)) async
    !$acc exit data delete(message%message_local_compact%field_ind_dest(:)) async
    !$acc exit data delete(message%message_local_compact%field_displ_dest(:)) async
    !$acc exit data delete(message%message_local_compact) async

    message%ondevice=.false.
  end subroutine

  subroutine finalize_message(message)
    use mpi_mod
    use mpipara, only : mpi_size
    type(t_message), intent(inout) :: message
    integer :: i, ierr

    !$omp barrier
    !$omp master
    if(message%send_seq /= message%wait_seq) call dynamico_abort("No matching wait_message before finalization")

    if(message%ondevice) call message_delete_ondevice(message)
    deallocate(message%message_in)
    deallocate(message%message_out)
    deallocate(message%message_local)
    do i=0, mpi_size-1
      if(message%mpi_requests_in(i) /= MPI_REQUEST_NULL) call MPI_Request_free(message%mpi_requests_in(i), ierr)
      if(message%mpi_requests_out(i) /= MPI_REQUEST_NULL) call MPI_Request_free(message%mpi_requests_out(i), ierr)
      deallocate(message%mpi_buffer_in(i)%buff)
      deallocate(message%mpi_buffer_out(i)%buff)
    end do
    deallocate(message%mpi_buffer_in)
    deallocate(message%mpi_buffer_out)
    deallocate(message%mpi_requests_in)
    deallocate(message%mpi_requests_out)
    deallocate(message%message_in_compact)
    deallocate(message%message_out_compact)
    deallocate(message%message_local_compact)
    !$omp end master
    !$omp barrier
  end subroutine

  ! Halo to Buffer : copy outbound messages to MPI buffers
  subroutine copy_HtoB(message)
    use domain_mod, only : assigned_domain
    use omp_para, only : distrib_level
    type(t_message), intent(inout) :: message
    integer :: dim3, dim4, d3_begin, d3_end
    integer :: k, d3, d4, i
    integer :: local_displ

    dim4 = size(message%field(1)%rval4d, 3)
    dim3 = size(message%field(1)%rval4d, 2)
    CALL distrib_level( 1, dim3, d3_begin, d3_end )

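    ! Pack : for each halo point i of the flattened outbound list, copy
    ! field(field_ind(i))%rval4d(field_displ(i),d3,d4) into the send buffer of rank remote_rank(i)
    ! at offset mpi_displ(i) + level_offset(i)*((d3-1) + dim3*(d4-1)) ;
    ! levels d3_begin:d3_end are this OpenMP thread's share (distrib_level).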
    !$acc parallel loop collapse(3) present(message) default(present) async if(message%ondevice)
    do d4 = 1, dim4
      do d3 = d3_begin, d3_end
        do i=1, message%message_out_compact%npoints
          message%mpi_buffer_out( message%message_out_compact%remote_rank(i) )%buff( message%message_out_compact%mpi_displ(i) + message%message_out_compact%level_offset(i)*( (d3-1) + dim3*(d4-1) ) ) &
            = message%field(message%message_out_compact%field_ind(i))%rval4d( message%message_out_compact%field_displ(i), d3, d4 )
        end do
      end do
    end do

  end subroutine

  ! Halo to Halo : copy local messages from source field to destination field
  subroutine copy_HtoH(message)
    use domain_mod, only : assigned_domain
    use omp_para, only : distrib_level
    type(t_message), intent(inout) :: message
    integer :: dim3, dim4, d3_begin, d3_end
    integer :: k, d3, d4, i

    dim4 = size(message%field(1)%rval4d, 3)
    dim3 = size(message%field(1)%rval4d, 2)
    CALL distrib_level( 1, dim3, d3_begin, d3_end )

    ! TODO : too many copies when tiles are distributed among threads
    !$acc parallel loop collapse(3) present(message) default(present) async if(message%ondevice)
    do d4 = 1, dim4
      do d3 = d3_begin, d3_end
        do i=1, message%message_local_compact%npoints
          message%field(message%message_local_compact%field_ind_dest(i))%rval4d( message%message_local_compact%field_displ_dest(i), d3, d4 ) &
            = message%message_local_compact%sign(i)*message%field(message%message_local_compact%field_ind_src(i))%rval4d( message%message_local_compact%field_displ_src(i), d3, d4 )
        end do
      end do
    end do
  end subroutine

  ! Buffer to Halo : copy inbound messages to the field
  subroutine copy_BtoH(message)
    use domain_mod, only : assigned_domain
    use omp_para, only : distrib_level
    type(t_message), intent(inout) :: message
    integer :: dim3, dim4, d3_begin, d3_end
    integer :: k, d3, d4, i
    integer :: last_point

    dim4 = size(message%field(1)%rval4d, 3)
    dim3 = size(message%field(1)%rval4d, 2)
    CALL distrib_level( 1, dim3, d3_begin, d3_end )

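    ! Unpack : inverse of copy_HtoB ; read each point from the receive buffer of rank remote_rank(i)
    ! and store it into the destination field, applying sign(i) for vector (edge) components.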
    !$acc parallel loop collapse(3) present(message) default(present) async if(message%ondevice)
    do d4 = 1, dim4
      do d3 = d3_begin, d3_end
        do i=1, message%message_in_compact%npoints
          message%field(message%message_in_compact%field_ind(i))%rval4d( message%message_in_compact%field_displ(i), d3, d4 ) &
            = message%message_in_compact%sign(i)*message%mpi_buffer_in( message%message_in_compact%remote_rank(i) )%buff( message%message_in_compact%mpi_displ(i) + message%message_in_compact%level_offset(i)*( (d3-1) + dim3*(d4-1) ) )
        end do
      end do
    end do

  end subroutine

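  ! send_message : pack outbound halos into MPI buffers (copy_HtoB), start the persistent
  ! receive requests first and then the sends, and overlap the local intra-rank copies
  ! (copy_HtoH) with communication. Every send_message must be matched by a wait_message.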
  subroutine send_message(field, message)
    use mpi_mod
    use mpipara, only : mpi_size
    type(t_field),pointer :: field(:)
    type(t_message), target :: message
    integer :: ierr, i

    call enter_profile(profile_mpi)

    ! Needed because rval4d is set in init_message
    if( .not. associated( message%field, field ) ) &
      call dynamico_abort("send_message must be called with the same field used in init_message")

    ! Prepare 'message' for on-device copies if the field is on device
    !$omp master
    if( field(1)%ondevice .and. .not. message%ondevice ) call message_create_ondevice(message)
    if( field(1)%ondevice .neqv. message%ondevice ) call dynamico_abort("send_message : internal device/host memory synchronization error")
    ! Check that the previous message has been waited for
    if(message%send_seq /= message%wait_seq) &
      call dynamico_abort("No matching wait_message before new send_message")
    message%send_seq = message%send_seq + 1
    !$omp end master

    call enter_profile(profile_mpi_barrier)
    !$omp barrier
    call exit_profile(profile_mpi_barrier)

    call enter_profile(profile_mpi_copies)
    call copy_HtoB(message)
    call exit_profile(profile_mpi_copies)

    call enter_profile(profile_mpi_barrier)
    !$acc wait
    call exit_profile(profile_mpi_barrier)

    !$omp master
    do i=0, mpi_size-1
      if(message%mpi_requests_in(i) /= MPI_REQUEST_NULL) call MPI_Start( message%mpi_requests_in(i), ierr )
    end do
    !$omp end master

    call enter_profile(profile_mpi_barrier)
    !$omp barrier
    call exit_profile(profile_mpi_barrier)

    !$omp master
    do i=0, mpi_size-1
      if(message%mpi_requests_out(i) /= MPI_REQUEST_NULL) call MPI_Start( message%mpi_requests_out(i), ierr )
    end do
    !$omp end master

    call enter_profile(profile_mpi_copies)
    call copy_HtoH(message)
    call exit_profile(profile_mpi_copies)

    call exit_profile(profile_mpi)
  end subroutine

  subroutine test_message(message)
    use mpi_mod
    type(t_message) :: message
    integer :: ierr
    logical :: completed

    !!$omp master
    !call MPI_Testall( size(message%mpi_requests_out), message%mpi_requests_out, completed, MPI_STATUSES_IGNORE, ierr )
    !call MPI_Testall( size(message%mpi_requests_in), message%mpi_requests_in, completed, MPI_STATUSES_IGNORE, ierr )
    !!$omp end master
  end subroutine

  subroutine wait_message(message)
    use mpi_mod
    type(t_message), target :: message
    integer :: ierr

    ! Check that the message has been sent and not received yet
    ! note : a barrier is needed between this check and the send_seq increment, and between this check and the wait_seq increment
    ! note : watch out for integer overflow, a = b+1 does not imply a > b
    if(message%send_seq /= message%wait_seq+1) then
      print*, "WARNING : wait_message called multiple times for one send_message, skipping"
      return ! Don't receive the message if it has already been received
    end if

    call enter_profile(profile_mpi)

    call enter_profile(profile_mpi_waitall)
    !$omp master
    call MPI_Waitall( size(message%mpi_requests_out), message%mpi_requests_out, MPI_STATUSES_IGNORE, ierr )
    call MPI_Waitall( size(message%mpi_requests_in), message%mpi_requests_in, MPI_STATUSES_IGNORE, ierr )
    !$omp end master
    call exit_profile(profile_mpi_waitall)

    call enter_profile(profile_mpi_barrier)
    !$omp barrier
    call exit_profile(profile_mpi_barrier)

    call enter_profile(profile_mpi_copies)
    call copy_BtoH(message)
    call exit_profile(profile_mpi_copies)

    !$omp master
    message%wait_seq = message%wait_seq + 1
    !$omp end master

    call enter_profile(profile_mpi_barrier)
    !$omp barrier
    call exit_profile(profile_mpi_barrier)

    call exit_profile(profile_mpi)
  end subroutine
end module