MESH: Rewrite prolong_block() in order to reduce its memory usage.

Instead of using a temporary array for all variables, we now reuse the
same array for the prolongation of each variable.

Signed-off-by: Grzegorz Kowal <grzegorz@amuncode.org>
This commit is contained in:
Grzegorz Kowal 2021-11-11 16:11:10 -03:00
parent 28d2dd061b
commit c0c549453d

View File

@ -53,7 +53,7 @@ module mesh
! allocatable array for prolongation ! allocatable array for prolongation
! !
real(kind=8), dimension(:,:,:,:), allocatable :: up real(kind=8), dimension(:,:,:), allocatable :: work
! by default everything is private ! by default everything is private
! !
@ -94,7 +94,6 @@ module mesh
! import external procedures and variables ! import external procedures and variables
! !
use coordinates , only : ni => ncells, ng => nghosts, toplev use coordinates , only : ni => ncells, ng => nghosts, toplev
use equations , only : nv
use iso_fortran_env, only : error_unit use iso_fortran_env, only : error_unit
use mpitools , only : master, nprocs use mpitools , only : master, nprocs
@ -188,7 +187,7 @@ module mesh
! allocate array for the prolongation array ! allocate array for the prolongation array
! !
allocate(up(nv, pm(1), pm(2), pm(3)), stat = status) allocate(work(pm(1), pm(2), pm(3)), stat = status)
if (status /= 0) then if (status /= 0) then
write(error_unit,"('[',a,']: ',a)") trim(loc) & write(error_unit,"('[',a,']: ',a)") trim(loc) &
@ -255,8 +254,8 @@ module mesh
! deallocate prolongation array ! deallocate prolongation array
! !
if (allocated(up)) then if (allocated(work)) then
deallocate(up, stat = status) deallocate(work, stat = status)
if (status /= 0) then if (status /= 0) then
write(error_unit,"('[',a,']: ',a)") trim(loc) & write(error_unit,"('[',a,']: ',a)") trim(loc) &
, "Prolongation array could not be deallocated!" , "Prolongation array could not be deallocated!"
@ -1136,21 +1135,25 @@ module mesh
type(block_meta), pointer, intent(inout) :: pmeta type(block_meta), pointer, intent(inout) :: pmeta
integer , intent(out) :: status integer , intent(out) :: status
integer :: p
integer :: i, il, iu, ic, ip, im1, ip1
integer :: j, jl, ju, jc, jp, jm1, jp1
integer :: k, kc
#if NDIMS == 3
integer :: kl, ku, kp, km1, kp1
#endif /* NDIMS == 3 */
real(kind=8) :: dul, dur, du1, du2
#if NDIMS == 3
real(kind=8) :: du3, du4
#endif /* NDIMS == 3 */
type(block_meta), pointer :: pchild type(block_meta), pointer :: pchild
type(block_data), pointer :: pdata type(block_data), pointer :: pdata
integer :: n, p, nl, nu
integer :: i, im, ip
integer :: j, jm, jp
#if NDIMS == 3
integer :: k, km, kp
#else /* NDIMS == 3 */
integer :: k
#endif /* NDIMS == 3 */
real(kind=8) :: dul, dur
#if NDIMS == 3
real(kind=8) :: du1, du2, du3, du4
#else /* NDIMS == 3 */
real(kind=8) :: du1, du2
#endif /* NDIMS == 3 */
integer , dimension(NDIMS) :: l, u
real(kind=8), dimension(NDIMS) :: du real(kind=8), dimension(NDIMS) :: du
character(len=*), parameter :: loc = 'MESH::prolong_block()' character(len=*), parameter :: loc = 'MESH::prolong_block()'
@ -1163,63 +1166,57 @@ module mesh
status = 0 status = 0
pdata => pmeta%data
il = nb - nh
iu = ne + nh
jl = nb - nh
ju = ne + nh
#if NDIMS == 3
kl = nb - nh
ku = ne + nh
#endif /* NDIMS == 3 */
#if NDIMS == 2 #if NDIMS == 2
k = 1 k = 1
kc = 1
#endif /* NDIMS == 2 */ #endif /* NDIMS == 2 */
nl = nb - nh
nu = ne + nh
pdata => pmeta%data
do n = 1, nv
#if NDIMS == 3 #if NDIMS == 3
do k = kl, ku do k = nl, nu
km1 = k - 1 km = k - 1
kp1 = k + 1 kp = k + 1
kc = 2 * (k - kl) + 1 l(3) = 2 * (k - nl) + 1
kp = kc + 1 u(3) = l(3) + 1
#endif /* NDIMS == 3 */ #endif /* NDIMS == 3 */
do j = jl, ju do j = nl, nu
jm1 = j - 1 jm = j - 1
jp1 = j + 1 jp = j + 1
jc = 2 * (j - jl) + 1 l(2) = 2 * (j - nl) + 1
jp = jc + 1 u(2) = l(2) + 1
do i = il, iu do i = nl, nu
im1 = i - 1 im = i - 1
ip1 = i + 1 ip = i + 1
ic = 2 * (i - il) + 1 l(1) = 2 * (i - nl) + 1
ip = ic + 1 u(1) = l(1) + 1
do p = 1, nv dul = pdata%u(n,i ,j,k) - pdata%u(n,im,j,k)
dur = pdata%u(n,ip,j,k) - pdata%u(n,i ,j,k)
dul = pdata%u(p,i ,j,k) - pdata%u(p,im1,j,k)
dur = pdata%u(p,ip1,j,k) - pdata%u(p,i ,j,k)
du(1) = limiter_prol(2.5d-01, dul, dur) du(1) = limiter_prol(2.5d-01, dul, dur)
dul = pdata%u(p,i,j ,k) - pdata%u(p,i,jm1,k) dul = pdata%u(n,i,j ,k) - pdata%u(n,i,jm,k)
dur = pdata%u(p,i,jp1,k) - pdata%u(p,i,j ,k) dur = pdata%u(n,i,jp,k) - pdata%u(n,i,j ,k)
du(2) = limiter_prol(2.5d-01, dul, dur) du(2) = limiter_prol(2.5d-01, dul, dur)
#if NDIMS == 3 #if NDIMS == 3
dul = pdata%u(p,i,j,k ) - pdata%u(p,i,j,km1) dul = pdata%u(n,i,j,k ) - pdata%u(n,i,j,km)
dur = pdata%u(p,i,j,kp1) - pdata%u(p,i,j,k ) dur = pdata%u(n,i,j,kp) - pdata%u(n,i,j,k )
du(3) = limiter_prol(2.5d-01, dul, dur) du(3) = limiter_prol(2.5d-01, dul, dur)
#endif /* NDIMS == 3 */ #endif /* NDIMS == 3 */
if (positive(p) .and. pdata%u(p,i,j,k) < sum(abs(du(1:NDIMS)))) then if (positive(n) .and. pdata%u(n,i,j,k) < sum(abs(du(1:NDIMS)))) then
if (pdata%u(p,i,j,k) > 0.0d+00) then if (pdata%u(n,i,j,k) > 0.0d+00) then
do while (pdata%u(p,i,j,k) <= sum(abs(du(1:NDIMS)))) do while (pdata%u(n,i,j,k) <= sum(abs(du(1:NDIMS))))
du(:) = 0.5d+00 * du(:) du(:) = 0.5d+00 * du(:)
end do end do
else else
write(error_unit,"('[',a,']: ',a,3i4,a)") trim(loc) & write(error_unit,"('[',a,']: ',a,3i4,a)") trim(loc), &
, "Positive variable is not positive at (", i, j, k, " )" "Positive variable is not positive at (", i, j, k, " )"
du(:) = 0.0d+00 du(:) = 0.0d+00
end if end if
end if end if
@ -1227,10 +1224,10 @@ module mesh
#if NDIMS == 2 #if NDIMS == 2
du1 = du(1) + du(2) du1 = du(1) + du(2)
du2 = du(1) - du(2) du2 = du(1) - du(2)
up(p,ic,jc,kc) = pdata%u(p,i,j,k) - du1 work(l(1),l(2),k) = pdata%u(n,i,j,k) - du1
up(p,ip,jc,kc) = pdata%u(p,i,j,k) + du2 work(u(1),l(2),k) = pdata%u(n,i,j,k) + du2
up(p,ic,jp,kc) = pdata%u(p,i,j,k) - du2 work(l(1),u(2),k) = pdata%u(n,i,j,k) - du2
up(p,ip,jp,kc) = pdata%u(p,i,j,k) + du1 work(u(1),u(2),k) = pdata%u(n,i,j,k) + du1
#endif /* NDIMS == 2 */ #endif /* NDIMS == 2 */
#if NDIMS == 3 #if NDIMS == 3
@ -1238,18 +1235,17 @@ module mesh
du2 = du(1) - du(2) - du(3) du2 = du(1) - du(2) - du(3)
du3 = du(1) - du(2) + du(3) du3 = du(1) - du(2) + du(3)
du4 = du(1) + du(2) - du(3) du4 = du(1) + du(2) - du(3)
up(p,ic,jc,kc) = pdata%u(p,i,j,k) - du1 work(l(1),l(2),l(3)) = pdata%u(n,i,j,k) - du1
up(p,ip,jc,kc) = pdata%u(p,i,j,k) + du2 work(u(1),l(2),l(3)) = pdata%u(n,i,j,k) + du2
up(p,ic,jp,kc) = pdata%u(p,i,j,k) - du3 work(l(1),u(2),l(3)) = pdata%u(n,i,j,k) - du3
up(p,ip,jp,kc) = pdata%u(p,i,j,k) + du4 work(u(1),u(2),l(3)) = pdata%u(n,i,j,k) + du4
up(p,ic,jc,kp) = pdata%u(p,i,j,k) - du4 work(l(1),l(2),u(3)) = pdata%u(n,i,j,k) - du4
up(p,ip,jc,kp) = pdata%u(p,i,j,k) + du3 work(u(1),l(2),u(3)) = pdata%u(n,i,j,k) + du3
up(p,ic,jp,kp) = pdata%u(p,i,j,k) - du2 work(l(1),u(2),u(3)) = pdata%u(n,i,j,k) - du2
up(p,ip,jp,kp) = pdata%u(p,i,j,k) + du1 work(u(1),u(2),u(3)) = pdata%u(n,i,j,k) + du1
#endif /* NDIMS == 3 */ #endif /* NDIMS == 3 */
end do end do
end do end do
end do
#if NDIMS == 3 #if NDIMS == 3
end do end do
#endif /* NDIMS == 3 */ #endif /* NDIMS == 3 */
@ -1258,32 +1254,18 @@ module mesh
pchild => pmeta%child(p)%ptr pchild => pmeta%child(p)%ptr
ic = pchild%pos(1) l(1:NDIMS) = pchild%pos(1:NDIMS) * ni + 1
jc = pchild%pos(2) u(1:NDIMS) = l(1:NDIMS) + nn - 1
#if NDIMS == 3
kc = pchild%pos(3)
#endif /* NDIMS == 3 */
il = 1 + ic * ni
jl = 1 + jc * ni
#if NDIMS == 3
kl = 1 + kc * ni
#endif /* NDIMS == 3 */
iu = il + nn - 1
ju = jl + nn - 1
#if NDIMS == 3
ku = kl + nn - 1
#endif /* NDIMS == 3 */
#if NDIMS == 2 #if NDIMS == 2
pchild%data%u(1:nv,:,:,:) = up(1:nv,il:iu,jl:ju, : ) pchild%data%u(n,:,:,1) = work(l(1):u(1),l(2):u(2),k)
#endif /* NDIMS == 2 */ #endif /* NDIMS == 2 */
#if NDIMS == 3 #if NDIMS == 3
pchild%data%u(1:nv,:,:,:) = up(1:nv,il:iu,jl:ju,kl:ku) pchild%data%u(n,:,:,:) = work(l(1):u(1),l(2):u(2),l(3):u(3))
#endif /* NDIMS == 3 */ #endif /* NDIMS == 3 */
end do ! nchildren end do ! nchildren
end do ! n = 1, nv
#ifdef PROFILE #ifdef PROFILE
call stop_timer(imp) call stop_timer(imp)