From c0c549453defef81b01c39af1576d085f3cf276d Mon Sep 17 00:00:00 2001
From: Grzegorz Kowal <grzegorz@amuncode.org>
Date: Thu, 11 Nov 2021 16:11:10 -0300
Subject: [PATCH] MESH: Rewrite prolong_block() in order to reduce its memory
 usage.

Instead of allocating a temporary array holding all variables at once, a
single work array sized for one variable is now reused, with each variable
prolonged and copied to the children in turn.

Signed-off-by: Grzegorz Kowal <grzegorz@amuncode.org>
---
 sources/mesh.F90 | 168 +++++++++++++++++++++--------------------------
 1 file changed, 75 insertions(+), 93 deletions(-)

diff --git a/sources/mesh.F90 b/sources/mesh.F90
index 09df371..391aa81 100644
--- a/sources/mesh.F90
+++ b/sources/mesh.F90
@@ -53,7 +53,7 @@ module mesh
 
 ! allocatable array for prolongation
 !
-  real(kind=8), dimension(:,:,:,:), allocatable :: up
+  real(kind=8), dimension(:,:,:), allocatable :: work
 
 ! by default everything is private
 !
@@ -94,7 +94,6 @@ module mesh
 ! import external procedures and variables
 !
     use coordinates    , only : ni => ncells, ng => nghosts, toplev
-    use equations      , only : nv
     use iso_fortran_env, only : error_unit
     use mpitools       , only : master, nprocs
 
@@ -188,7 +187,7 @@ module mesh
 
 ! allocate array for the prolongation array
 !
-    allocate(up(nv, pm(1), pm(2), pm(3)), stat = status)
+    allocate(work(pm(1), pm(2), pm(3)), stat = status)
 
     if (status /= 0) then
       write(error_unit,"('[',a,']: ',a)") trim(loc)                            &
@@ -255,8 +254,8 @@ module mesh
 
 ! deallocate prolongation array
 !
-    if (allocated(up)) then
-      deallocate(up, stat = status)
+    if (allocated(work)) then
+      deallocate(work, stat = status)
       if (status /= 0) then
         write(error_unit,"('[',a,']: ',a)") trim(loc)                          &
                         , "Prolongation array could not be deallocated!"
@@ -1136,21 +1135,25 @@ module mesh
     type(block_meta), pointer, intent(inout) :: pmeta
     integer                  , intent(out)   :: status
 
-    integer      :: p
-    integer      :: i, il, iu, ic, ip, im1, ip1
-    integer      :: j, jl, ju, jc, jp, jm1, jp1
-    integer      :: k, kc
-#if NDIMS == 3
-    integer      :: kl, ku, kp, km1, kp1
-#endif /* NDIMS == 3 */
-    real(kind=8) :: dul, dur, du1, du2
-#if NDIMS == 3
-    real(kind=8) :: du3, du4
-#endif /* NDIMS == 3 */
-
     type(block_meta), pointer :: pchild
     type(block_data), pointer :: pdata
 
+    integer      :: n, p, nl, nu
+    integer      :: i, im, ip
+    integer      :: j, jm, jp
+#if NDIMS == 3
+    integer      :: k, km, kp
+#else /* NDIMS == 3 */
+    integer      :: k
+#endif /* NDIMS == 3 */
+    real(kind=8) :: dul, dur
+#if NDIMS == 3
+    real(kind=8) :: du1, du2, du3, du4
+#else /* NDIMS == 3 */
+    real(kind=8) :: du1, du2
+#endif /* NDIMS == 3 */
+
+    integer     , dimension(NDIMS) :: l, u
     real(kind=8), dimension(NDIMS) :: du
 
     character(len=*), parameter :: loc = 'MESH::prolong_block()'
@@ -1163,63 +1166,57 @@ module mesh
 
     status = 0
 
-    pdata => pmeta%data
-
-    il = nb - nh
-    iu = ne + nh
-    jl = nb - nh
-    ju = ne + nh
-#if NDIMS == 3
-    kl = nb - nh
-    ku = ne + nh
-#endif /* NDIMS == 3 */
-
 #if NDIMS == 2
     k  = 1
-    kc = 1
 #endif /* NDIMS == 2 */
+
+    nl = nb - nh
+    nu = ne + nh
+
+    pdata => pmeta%data
+
+    do n = 1, nv
+
 #if NDIMS == 3
-    do k = kl, ku
-      km1 = k - 1
-      kp1 = k + 1
-      kc  = 2 * (k - kl) + 1
-      kp  = kc + 1
+      do k = nl, nu
+        km   = k - 1
+        kp   = k + 1
+        l(3) = 2 * (k - nl) + 1
+        u(3) = l(3) + 1
 #endif /* NDIMS == 3 */
-      do j = jl, ju
-        jm1 = j - 1
-        jp1 = j + 1
-        jc  = 2 * (j - jl) + 1
-        jp  = jc + 1
-        do i = il, iu
-          im1 = i - 1
-          ip1 = i + 1
-          ic  = 2 * (i - il) + 1
-          ip  = ic + 1
+        do j = nl, nu
+          jm   = j - 1
+          jp   = j + 1
+          l(2) = 2 * (j - nl) + 1
+          u(2) = l(2) + 1
+          do i = nl, nu
+            im   = i - 1
+            ip   = i + 1
+            l(1) = 2 * (i - nl) + 1
+            u(1) = l(1) + 1
 
-          do p = 1, nv
-
-            dul   = pdata%u(p,i  ,j,k) - pdata%u(p,im1,j,k)
-            dur   = pdata%u(p,ip1,j,k) - pdata%u(p,i  ,j,k)
+            dul   = pdata%u(n,i ,j,k) - pdata%u(n,im,j,k)
+            dur   = pdata%u(n,ip,j,k) - pdata%u(n,i ,j,k)
             du(1) = limiter_prol(2.5d-01, dul, dur)
 
-            dul   = pdata%u(p,i,j  ,k) - pdata%u(p,i,jm1,k)
-            dur   = pdata%u(p,i,jp1,k) - pdata%u(p,i,j  ,k)
+            dul   = pdata%u(n,i,j ,k) - pdata%u(n,i,jm,k)
+            dur   = pdata%u(n,i,jp,k) - pdata%u(n,i,j ,k)
             du(2) = limiter_prol(2.5d-01, dul, dur)
 
 #if NDIMS == 3
-            dul   = pdata%u(p,i,j,k  ) - pdata%u(p,i,j,km1)
-            dur   = pdata%u(p,i,j,kp1) - pdata%u(p,i,j,k  )
+            dul   = pdata%u(n,i,j,k ) - pdata%u(n,i,j,km)
+            dur   = pdata%u(n,i,j,kp) - pdata%u(n,i,j,k )
             du(3) = limiter_prol(2.5d-01, dul, dur)
 #endif /* NDIMS == 3 */
 
-            if (positive(p) .and. pdata%u(p,i,j,k) < sum(abs(du(1:NDIMS)))) then
-              if (pdata%u(p,i,j,k) > 0.0d+00) then
-                do while (pdata%u(p,i,j,k) <= sum(abs(du(1:NDIMS))))
+            if (positive(n) .and. pdata%u(n,i,j,k) < sum(abs(du(1:NDIMS)))) then
+              if (pdata%u(n,i,j,k) > 0.0d+00) then
+                do while (pdata%u(n,i,j,k) <= sum(abs(du(1:NDIMS))))
                   du(:) = 0.5d+00 * du(:)
                 end do
               else
-                write(error_unit,"('[',a,']: ',a,3i4,a)") trim(loc)            &
-                     , "Positive variable is not positive at (", i, j, k, " )"
+                write(error_unit,"('[',a,']: ',a,3i4,a)") trim(loc),           &
+                      "Positive variable is not positive at (", i, j, k, " )"
                 du(:) = 0.0d+00
               end if
             end if
@@ -1227,10 +1224,10 @@ module mesh
 #if NDIMS == 2
             du1 = du(1) + du(2)
             du2 = du(1) - du(2)
-            up(p,ic,jc,kc) = pdata%u(p,i,j,k) - du1
-            up(p,ip,jc,kc) = pdata%u(p,i,j,k) + du2
-            up(p,ic,jp,kc) = pdata%u(p,i,j,k) - du2
-            up(p,ip,jp,kc) = pdata%u(p,i,j,k) + du1
+            work(l(1),l(2),k) = pdata%u(n,i,j,k) - du1
+            work(u(1),l(2),k) = pdata%u(n,i,j,k) + du2
+            work(l(1),u(2),k) = pdata%u(n,i,j,k) - du2
+            work(u(1),u(2),k) = pdata%u(n,i,j,k) + du1
 #endif /* NDIMS == 2 */
 
 #if NDIMS == 3
@@ -1238,52 +1235,37 @@ module mesh
             du2 = du(1) - du(2) - du(3)
             du3 = du(1) - du(2) + du(3)
             du4 = du(1) + du(2) - du(3)
-            up(p,ic,jc,kc) = pdata%u(p,i,j,k) - du1
-            up(p,ip,jc,kc) = pdata%u(p,i,j,k) + du2
-            up(p,ic,jp,kc) = pdata%u(p,i,j,k) - du3
-            up(p,ip,jp,kc) = pdata%u(p,i,j,k) + du4
-            up(p,ic,jc,kp) = pdata%u(p,i,j,k) - du4
-            up(p,ip,jc,kp) = pdata%u(p,i,j,k) + du3
-            up(p,ic,jp,kp) = pdata%u(p,i,j,k) - du2
-            up(p,ip,jp,kp) = pdata%u(p,i,j,k) + du1
+            work(l(1),l(2),l(3)) = pdata%u(n,i,j,k) - du1
+            work(u(1),l(2),l(3)) = pdata%u(n,i,j,k) + du2
+            work(l(1),u(2),l(3)) = pdata%u(n,i,j,k) - du3
+            work(u(1),u(2),l(3)) = pdata%u(n,i,j,k) + du4
+            work(l(1),l(2),u(3)) = pdata%u(n,i,j,k) - du4
+            work(u(1),l(2),u(3)) = pdata%u(n,i,j,k) + du3
+            work(l(1),u(2),u(3)) = pdata%u(n,i,j,k) - du2
+            work(u(1),u(2),u(3)) = pdata%u(n,i,j,k) + du1
 #endif /* NDIMS == 3 */
           end do
         end do
+#if NDIMS == 3
       end do
-#if NDIMS == 3
-    end do
 #endif /* NDIMS == 3 */
 
-    do p = 1, nchildren
+      do p = 1, nchildren
 
-      pchild => pmeta%child(p)%ptr
+        pchild => pmeta%child(p)%ptr
 
-      ic = pchild%pos(1)
-      jc = pchild%pos(2)
-#if NDIMS == 3
-      kc = pchild%pos(3)
-#endif /* NDIMS == 3 */
-
-      il = 1 + ic * ni
-      jl = 1 + jc * ni
-#if NDIMS == 3
-      kl = 1 + kc * ni
-#endif /* NDIMS == 3 */
-
-      iu = il + nn - 1
-      ju = jl + nn - 1
-#if NDIMS == 3
-      ku = kl + nn - 1
-#endif /* NDIMS == 3 */
+        l(1:NDIMS) = pchild%pos(1:NDIMS) * ni + 1
+        u(1:NDIMS) =          l(1:NDIMS) + nn - 1
 
 #if NDIMS == 2
-      pchild%data%u(1:nv,:,:,:) = up(1:nv,il:iu,jl:ju,  :  )
+        pchild%data%u(n,:,:,1) = work(l(1):u(1),l(2):u(2),k)
 #endif /* NDIMS == 2 */
 #if NDIMS == 3
-      pchild%data%u(1:nv,:,:,:) = up(1:nv,il:iu,jl:ju,kl:ku)
+        pchild%data%u(n,:,:,:) = work(l(1):u(1),l(2):u(2),l(3):u(3))
 #endif /* NDIMS == 3 */
 
-    end do ! nchildren
+      end do ! nchildren
+    end do ! n = 1, nv
 
 #ifdef PROFILE
     call stop_timer(imp)