Skip to content

Commit 4845ce2

Browse files
committed
Consolidating OpenACC device-host memory transfers
This PR consolidates much of the OpenACC host-device data transfer that occurs during dynamics execution into two subroutines, mpas_atm_pre_dynamics_h2d and mpas_atm_post_dynamics_d2h, which are called before and after the call to the atm_srk3 subroutine. Because atm_compute_solve_diagnostics is also called once before the start of the model run, there is also a pair of subroutines, mpas_atm_pre_computesolvediag_h2d and mpas_atm_post_computesolvediag_d2h, to handle data movement around that first call to atm_compute_solve_diagnostics. Any fields copied onto the device in these subroutines are removed from the explicit data movement statements in the dynamical core. The mesh/time-invariant fields are still copied onto the device in mpas_atm_dynamics_init and removed from the device in mpas_atm_dynamics_finalize, with the exception of select fields moved in mpas_atm_pre_computesolvediag_h2d and mpas_atm_post_computesolvediag_d2h. This is a special case because atm_compute_solve_diagnostics is called for the first time before the call to mpas_atm_dynamics_init. This PR also includes explicit host-device data transfers in the mpas_atm_iau, mpas_atmphys_interface and mpas_atmphys_todynamics modules to ensure that the physics and IAU regions, which run on the CPU, use the latest values from the dynamical core running on GPUs, and vice versa. In addition, this PR includes explicit data transfers around halo exchanges in the atm_srk3 subroutine. These data-transfer subroutines and the acc update statements are an interim solution until we have a book-keeping method in place. This PR also introduces a couple of new timers to keep track of the cost of data transfers.
1 parent 956b020 commit 4845ce2

File tree

6 files changed

+1071
-376
lines changed

6 files changed

+1071
-376
lines changed

src/core_atmosphere/dynamics/mpas_atm_boundaries.F

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -310,18 +310,14 @@ subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t
310310
nullify(tend)
311311
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)
312312

313-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_tend [ACC_data_xfer]')
314313
if (associated(tend)) then
315-
!$acc enter data copyin(tend)
316314
else
317315
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
318-
!$acc enter data copyin(tend_scalars)
319316

320317
! Ensure the integer pointed to by idx_ptr is copied to the gpu device
321318
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)
322319
idx = idx_ptr
323320
end if
324-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_tend [ACC_data_xfer]')
325321

326322
!$acc parallel default(present)
327323
if (associated(tend)) then
@@ -341,13 +337,6 @@ subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t
341337
end if
342338
!$acc end parallel
343339

344-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_tend [ACC_data_xfer]')
345-
if (associated(tend)) then
346-
!$acc exit data delete(tend)
347-
else
348-
!$acc exit data delete(tend_scalars)
349-
end if
350-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_tend [ACC_data_xfer]')
351340

352341
end subroutine mpas_atm_get_bdy_tend
353342

@@ -448,9 +437,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del
448437
! query the field as a scalar constituent
449438
!
450439
if (associated(tend) .and. associated(state)) then
451-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
452-
!$acc enter data copyin(tend, state)
453-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
454440

455441
!$acc parallel default(present)
456442
!$acc loop gang vector collapse(2)
@@ -461,20 +447,13 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del
461447
end do
462448
!$acc end parallel
463449

464-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
465-
!$acc exit data delete(tend, state)
466-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
467450
else
468451
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
469452
call mpas_pool_get_array(lbc, 'lbc_scalars', state_scalars, 2)
470453
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)
471454

472455
idx=idx_ptr ! Avoid non-array pointer for OpenACC
473456

474-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
475-
!$acc enter data copyin(tend_scalars, state_scalars)
476-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
477-
478457
!$acc parallel default(present)
479458
!$acc loop gang vector collapse(2)
480459
do i=1, horizDim+1
@@ -484,9 +463,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del
484463
end do
485464
!$acc end parallel
486465

487-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
488-
!$acc exit data delete(tend_scalars, state_scalars)
489-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
490466
end if
491467

492468
end subroutine mpas_atm_get_bdy_state_2d
@@ -567,10 +543,6 @@ subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim,
567543
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)
568544
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), state, 2)
569545

570-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
571-
!$acc enter data copyin(tend, state)
572-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
573-
574546
!$acc parallel default(present)
575547
!$acc loop gang vector collapse(3)
576548
do i=1, horizDim+1
@@ -582,10 +554,6 @@ subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim,
582554
end do
583555
!$acc end parallel
584556

585-
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
586-
!$acc exit data delete(tend, state)
587-
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
588-
589557
end subroutine mpas_atm_get_bdy_state_3d
590558

591559

src/core_atmosphere/dynamics/mpas_atm_iau.F

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,20 @@ module mpas_atm_iau
1313
use mpas_dmpar
1414
use mpas_constants
1515
use mpas_log, only : mpas_log_write
16+
use mpas_timer
1617

1718
!public :: atm_compute_iau_coef, atm_add_tend_anal_incr
1819

20+
21+
#ifdef MPAS_OPENACC
22+
#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X)
23+
#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X)
24+
#else
25+
#define MPAS_ACC_TIMER_START(X)
26+
#define MPAS_ACC_TIMER_STOP(X)
27+
#endif
28+
29+
1930
contains
2031

2132
!==================================================================================================
@@ -137,6 +148,7 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten
137148
call mpas_pool_get_array(state, 'scalars', scalars, 1)
138149
call mpas_pool_get_array(state, 'rho_zz', rho_zz, 2)
139150
call mpas_pool_get_array(diag , 'rho_edge', rho_edge)
151+
!$acc update self(theta_m, scalars, rho_zz, rho_edge)
140152

141153
call mpas_pool_get_dimension(state, 'moist_start', moist_start)
142154
call mpas_pool_get_dimension(state, 'moist_end', moist_end)
@@ -149,6 +161,8 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten
149161
! call mpas_pool_get_array(tend, 'rho_zz', tend_rho)
150162
! call mpas_pool_get_array(tend, 'theta_m', tend_theta)
151163
call mpas_pool_get_array(tend, 'scalars_tend', tend_scalars)
164+
!$acc update self(tend_scalars)
165+
MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer')
152166

153167
call mpas_pool_get_array(tend_iau, 'theta', theta_amb)
154168
call mpas_pool_get_array(tend_iau, 'rho', rho_amb)

0 commit comments

Comments
 (0)