!!******************************************************************************
!!
!!  This file is part of the AMUN source code, a program to perform
!!  Newtonian or relativistic magnetohydrodynamical simulations on uniform or
!!  adaptive mesh.
!!
!!  Copyright (C) 2008-2024 Grzegorz Kowal <grzegorz@amuncode.org>
!!
!!  This program is free software: you can redistribute it and/or modify
!!  it under the terms of the GNU General Public License as published by
!!  the Free Software Foundation, either version 3 of the License, or
!!  (at your option) any later version.
!!
!!  This program is distributed in the hope that it will be useful,
!!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!!  GNU General Public License for more details.
!!
!!  You should have received a copy of the GNU General Public License
!!  along with this program.  If not, see <http://www.gnu.org/licenses/>.
!!
!!******************************************************************************
!!
!! module: REFINEMENT
!!
!!  This module handles the error estimation and refinement criterion
!!  determination.
!!
!!
!!******************************************************************************
!
module refinement

  implicit none

! interface of an optional, problem-specific refinement criterion; the
! returned value is compared against 'usermin' and 'usermax'
!
  abstract interface
    real(kind=4) function user_criterion_iface(pdata) result(crit)
      use blocks, only : block_data
      implicit none
      type(block_data), pointer, intent(in) :: pdata
    end function
  end interface

! user-supplied criterion; it is only evaluated when 'user' appears in
! 'refinement_variables' AND this pointer has been associated
!
  procedure(user_criterion_iface), pointer, save :: user_criterion => null()

! lower (derefinement) and upper (refinement) thresholds of each criterion
! and the relative weight 'epsref' used by the second derivative estimator
!
  real(kind=8), save :: crefmin = 2.0d-01
  real(kind=8), save :: crefmax = 8.0d-01
  real(kind=8), save :: vortmin = 1.0d-03
  real(kind=8), save :: vortmax = 1.0d-01
  real(kind=8), save :: currmin = 1.0d-03
  real(kind=8), save :: currmax = 1.0d-01
  real(kind=8), save :: usermin = 1.0d+99
  real(kind=8), save :: usermax = 1.0d+99
  real(kind=8), save :: epsref  = 1.0d-02

! flags selecting which criteria are evaluated; 'qvar_ref' has one entry
! per primitive variable
!
  logical, dimension(:), allocatable, save :: qvar_ref
  logical                           , save :: vort_ref = .false.
  logical                           , save :: curr_ref = .false.
  logical                           , save :: user_ref = .false.

  private

  public :: initialize_refinement, finalize_refinement, print_refinement
  public :: check_refinement_criterion
  public :: user_criterion

!- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!
  contains
!
!===============================================================================
!!
!!***  PUBLIC SUBROUTINES  *****************************************************
!!
!===============================================================================
!
!===============================================================================
!
! subroutine INITIALIZE_REFINEMENT:
! --------------------------------
!
!   Subroutine initializes module REFINEMENT: reads the thresholds and the
!   list of refined variables from the parameters, allocates the per-variable
!   selection flags and validates the parameter values.
!
!   Arguments:
!
!     verbose - flag determining if the subroutine should be verbose;
!     status  - return flag of the procedure execution status
!               (0 on success, non-zero on allocation failure, missing
!               criterion selection, or invalid parameter values);
!
!===============================================================================
!
  subroutine initialize_refinement(verbose, status)

    use equations , only : magnetized, nv, pvars
    use helpers   , only : print_message
    use parameters, only : get_parameter

    implicit none

    logical, intent(in)  :: verbose
    integer, intent(out) :: status

    logical            :: test
    integer            :: p
    character(len=255) :: variables

    character(len=*), parameter :: loc = 'REFINEMENT::initialize_refinement()'

!-------------------------------------------------------------------------------
!
    status = 0

! default list of refined variables; assigned here instead of in the
! declaration, which would give the local variable an implicit SAVE
!
    variables = "dens pres"

    call get_parameter("crefmin", crefmin)
    call get_parameter("crefmax", crefmax)
    call get_parameter("vortmin", vortmin)
    call get_parameter("vortmax", vortmax)
    call get_parameter("currmin", currmin)
    call get_parameter("currmax", currmax)
    call get_parameter("usermin", usermin)
    call get_parameter("usermax", usermax)
    call get_parameter("epsref" , epsref )

    call get_parameter("refinement_variables", variables)

! the flags are written right below, so bail out if the allocation failed
!
    allocate(qvar_ref(nv), stat=status)
    if (status /= 0) then
      call print_message(loc, &
                 "Could not allocate the space for refinement criterion names!")
      return
    end if

! a variable is selected when its name occurs anywhere in the list
! (plain substring match)
!
    do p = 1, nv
      qvar_ref(p) = index(variables, trim(pvars(p))) > 0
    end do

    vort_ref = index(variables, 'vort') > 0

    test = any(qvar_ref(:)) .or. vort_ref

    if (magnetized) then
      curr_ref = index(variables, 'curr') > 0 .or. index(variables, 'jabs') > 0
      test = test .or. curr_ref
    end if

    user_ref = index(variables, 'user') > 0
    test     = test .or. user_ref

    if (.not. test) then
      if (verbose) &
        call print_message(loc, "No refinement criterion has been selected!")
      status = 1
      return
    end if

    if (crefmin > crefmax .or. crefmin < 0.0d+00) then
      if (verbose) &
        call print_message(loc, &
                     "Wrong 'crefmin' or 'crefmax' parameters. " // &
                     "They should be positive and 'crefmin' <= 'crefmax'.")
      status = 1
      return
    end if

    if (vortmin > vortmax .or. vortmin < 0.0d+00) then
      if (verbose) &
        call print_message(loc, &
                     "Wrong 'vortmin' or 'vortmax' parameters. " // &
                     "They should be positive and 'vortmin' <= 'vortmax'.")
      status = 1
      return
    end if

    if ((currmin > currmax .or. currmin < 0.0d+00) .and. magnetized) then
      if (verbose) &
        call print_message(loc, &
                     "Wrong 'currmin' or 'currmax' parameters. " // &
                     "They should be positive and 'currmin' <= 'currmax'.")
      status = 1
      return
    end if

    if ((usermin > usermax .or. usermin < 0.0d+00) .and. user_ref) then
      if (verbose) &
        call print_message(loc, &
                     "Wrong 'usermin' or 'usermax' parameters. " // &
                     "They should be positive and 'usermin' <= 'usermax'.")
      status = 1
      return
    end if

    if (epsref <= 0.0d+00) then
      if (verbose) &
        call print_message(loc, &
                     "Wrong 'epsref' parameters. It should be positive.")
      status = 1
      return
    end if

!-------------------------------------------------------------------------------
!
  end subroutine initialize_refinement
!
!===============================================================================
!
! subroutine FINALIZE_REFINEMENT:
! ------------------------------
!
!   Subroutine releases memory used by the module variables.
!
!   Arguments:
!
!     status - return flag of the procedure execution status;
!
!===============================================================================
!
  subroutine finalize_refinement(status)

    use helpers, only : print_message

    implicit none

    integer, intent(out) :: status

    character(len=*), parameter :: loc = 'REFINEMENT::finalize_refinement()'

!-------------------------------------------------------------------------------
!
    status = 0

    if (allocated(qvar_ref)) then
      deallocate(qvar_ref, stat=status)
      if (status /= 0) &
        call print_message(loc, &
               "Could not deallocate space for the refinement criterion names!")
    end if

!-------------------------------------------------------------------------------
!
  end subroutine finalize_refinement
!
!===============================================================================
!
! subroutine PRINT_REFINEMENT:
! ---------------------------
!
!   Subroutine prints module parameters: the list of refined variables and
!   the thresholds of every active criterion.
!
!   Arguments:
!
!     verbose - flag determining if the subroutine should be verbose;
!
!===============================================================================
!
  subroutine print_refinement(verbose)

    use helpers  , only : print_section, print_parameter
    use equations, only : magnetized, pvars, nv

    implicit none

    logical, intent(in) :: verbose

! no initializer here: it would imply SAVE; the variable is reset below
!
    character(len=255) :: rvars
    integer            :: p

!-------------------------------------------------------------------------------
!
    if (.not. verbose) return

    rvars = ""

! the flags exist only after a successful initialization
!
    if (allocated(qvar_ref)) then
      do p = 1, nv
        if (qvar_ref(p)) rvars = adjustl(trim(rvars) // ' ' // trim(pvars(p)))
      end do
    end if
    if (vort_ref) then
      rvars = adjustl(trim(rvars) // ' vort')
    end if
    if (magnetized .and. curr_ref) then
      rvars = adjustl(trim(rvars) // ' curr')
    end if
    if (user_ref) then
      rvars = adjustl(trim(rvars) // ' user')
    end if

    call print_section(verbose, "Refinement")
    call print_parameter(verbose, "refined variables", rvars)
    call print_parameter(verbose, "2nd order error limits", crefmin, crefmax)
    if (vort_ref) &
      call print_parameter(verbose, "vorticity limits", vortmin, vortmax)
    if (magnetized .and. curr_ref) &
      call print_parameter(verbose, "current density limits", currmin, currmax)
    if (user_ref) &
      call print_parameter(verbose, "user criterion limits", usermin, usermax)

!-------------------------------------------------------------------------------
!
  end subroutine print_refinement
!
!===============================================================================
!
! function CHECK_REFINEMENT_CRITERION:
! -----------------------------------
!
!   Function scans the given data block and checks for the refinement
!   criterion.  It returns +1 if the criterion is met, which indicates that
!   the block needs to be refined, 0 if there is no need for the refinement,
!   and -1 if the block can be derefined.  The result is the maximum over
!   all active criteria, so any single criterion may force refinement.
!
!   Arguments:
!
!     pdata - pointer to the data block for which the refinement criterion
!             has to be determined;
!
!===============================================================================
!
  function check_refinement_criterion(pdata) result(criterion)

    use blocks   , only : block_data
    use equations, only : nv

    implicit none

    type(block_data), pointer, intent(in) :: pdata

    integer(kind=4) :: criterion

    integer      :: p
    real(kind=8) :: cref

!-------------------------------------------------------------------------------
!
    criterion = -1

    do p = 1, nv
      if (qvar_ref(p)) then

        cref = second_derivative_error(p, pdata)

        if (cref > crefmin) criterion = max(criterion, 0)
        if (cref > crefmax) criterion = max(criterion, 1)

      end if
    end do

    if (vort_ref) then
      cref = vorticity_magnitude(pdata)

      if (cref > vortmin) criterion = max(criterion, 0)
      if (cref > vortmax) criterion = max(criterion, 1)
    end if

    if (curr_ref) then
      cref = current_density_magnitude(pdata)

      if (cref > currmin) criterion = max(criterion, 0)
      if (cref > currmax) criterion = max(criterion, 1)
    end if

! the user criterion is silently skipped if no function has been associated
!
    if (user_ref .and. associated(user_criterion)) then
      cref = user_criterion(pdata)

      if (cref > usermin) criterion = max(criterion, 0)
      if (cref > usermax) criterion = max(criterion, 1)
    end if

!-------------------------------------------------------------------------------
!
  end function check_refinement_criterion
!
!===============================================================================
!!
!!***  PRIVATE SUBROUTINES  ****************************************************
!!
!===============================================================================
!
!===============================================================================
!
! function SECOND_DERIVATIVE_ERROR:
! --------------------------------
!
!   Function calculates the normalized second derivative error for a given
!   data block and selected primitive variable.  In each direction it
!   evaluates |fr + fl| / (|fr| + |fl| + epsref * fc + eps), where fr and fl
!   are the differences to the right and left neighbors and fc is a sum of
!   absolute values of the stencil, and returns the maximum over all
!   interior cells and directions.
!
!   Arguments:
!
!     iqt   - the index of primitive variable (non-positive gives 0);
!     pdata - pointer to the data block for which error is calculated;
!
!===============================================================================
!
  function second_derivative_error(iqt, pdata) result(error)

    use blocks     , only : block_data
    use coordinates, only : nbl, neu

    implicit none

    integer                  , intent(in) :: iqt
    type(block_data), pointer, intent(in) :: pdata

    real(kind=8) :: error

    integer      :: i, im1, ip1
    integer      :: j, jm1, jp1
    integer      :: k
#if NDIMS == 3
    integer      :: km1, kp1
#endif /* NDIMS == 3 */
    real(kind=8) :: fl, fr, fc, fx, fy
#if NDIMS == 3
    real(kind=8) :: fz
#endif /* NDIMS == 3 */

! guards the denominator against division by zero in uniform regions
!
    real(kind=8), parameter :: eps = epsilon(1.0d+00)

!-------------------------------------------------------------------------------
!
    error = 0.0d+00

#if NDIMS == 2
    k = 1
#endif /* NDIMS == 2 */

    if (iqt > 0) then

#if NDIMS == 3
      do k = nbl, neu
        km1 = k - 1
        kp1 = k + 1
#endif /* NDIMS == 3 */
        do j = nbl, neu
          jm1 = j - 1
          jp1 = j + 1
          do i = nbl, neu
            im1 = i - 1
            ip1 = i + 1

            fr   = pdata%q(iqt,ip1,j,k) - pdata%q(iqt,i  ,j,k)
            fl   = pdata%q(iqt,im1,j,k) - pdata%q(iqt,i  ,j,k)
            fc   = abs(pdata%q(iqt,ip1,j,k)) + abs(pdata%q(iqt,im1,j,k)) &
                                             + 2.0d+00 * abs(pdata%q(iqt,i,j,k))
            fx   = abs(fr + fl) / (abs(fr) + abs(fl) + epsref * fc + eps)

            fr   = pdata%q(iqt,i,jp1,k) - pdata%q(iqt,i,j  ,k)
            fl   = pdata%q(iqt,i,jm1,k) - pdata%q(iqt,i,j  ,k)
            fc   = abs(pdata%q(iqt,i,jp1,k)) + abs(pdata%q(iqt,i,jm1,k)) &
                                             + 2.0d+00 * abs(pdata%q(iqt,i,j,k))
            fy   = abs(fr + fl) / (abs(fr) + abs(fl) + epsref * fc + eps)

#if NDIMS == 3
            fr   = pdata%q(iqt,i,j,kp1) - pdata%q(iqt,i,j,k  )
            fl   = pdata%q(iqt,i,j,km1) - pdata%q(iqt,i,j,k  )
            fc   = abs(pdata%q(iqt,i,j,kp1)) + abs(pdata%q(iqt,i,j,km1)) &
                                             + 2.0d+00 * abs(pdata%q(iqt,i,j,k))
            fz   = abs(fr + fl) / (abs(fr) + abs(fl) + epsref * fc + eps)
#endif /* NDIMS == 3 */

#if NDIMS == 2
            error = max(error, fx, fy)
#endif /* NDIMS == 2 */
#if NDIMS == 3
            error = max(error, fx, fy, fz)
#endif /* NDIMS == 3 */

          end do
        end do
#if NDIMS == 3
      end do
#endif /* NDIMS == 3 */

    end if

!-------------------------------------------------------------------------------
!
  end function second_derivative_error
!
!===============================================================================
!
! function VORTICITY_MAGNITUDE:
! ----------------------------
!
!   Function finds the maximum values of the vorticity magnitude
!   for the current data block.
!
!   Arguments:
!
!     pdata - pointer to the data block for which error is calculated;
!
!===============================================================================
!
  function vorticity_magnitude(pdata) result(wmax)

    use blocks   , only : block_data
    use equations, only : ivx, ivz

    implicit none

    type(block_data), pointer, intent(in) :: pdata

    real(kind=4) :: wmax

    character(len=*), parameter :: loc = 'REFINEMENT::vorticity_magnitude()'

!-------------------------------------------------------------------------------
!
    wmax = curl_magnitude(pdata, ivx, ivz, loc)

!-------------------------------------------------------------------------------
!
  end function vorticity_magnitude
!
!===============================================================================
!
! function CURRENT_DENSITY_MAGNITUDE:
! ----------------------------------
!
!   Function finds the maximum values of the current density magnitude
!   for the current data block.  It returns 0 for non-magnetized equations.
!
!   Arguments:
!
!     pdata - pointer to the data block for which error is calculated;
!
!===============================================================================
!
  function current_density_magnitude(pdata) result(jmax)

    use blocks   , only : block_data
    use equations, only : magnetized
    use equations, only : ibx, ibz

    implicit none

    type(block_data), pointer, intent(in) :: pdata

    real(kind=4) :: jmax

    character(len=*), parameter :: loc = &
                                       'REFINEMENT::current_density_magnitude()'

!-------------------------------------------------------------------------------
!
    jmax = 0.0e+00

    if (.not. magnetized) return

    jmax = curl_magnitude(pdata, ibx, ibz, loc)

!-------------------------------------------------------------------------------
!
  end function current_density_magnitude
!
!===============================================================================
!
! function CURL_MAGNITUDE:
! -----------------------
!
!   Function computes the curl of the vector field stored in the primitive
!   variables ibeg:iend (three components) of the given data block, with
!   unit cell spacing, and returns the maximum of its magnitude over the
!   interior cells nbl:neu.  The squared magnitude is kept in double
!   precision up to the final square root.
!
!   Arguments:
!
!     pdata - pointer to the data block;
!     ibeg  - index of the first vector component in pdata%q;
!     iend  - index of the last vector component in pdata%q;
!     loc   - caller location used in the diagnostic messages;
!
!===============================================================================
!
  function curl_magnitude(pdata, ibeg, iend, loc) result(cmax)

    use blocks     , only : block_data
    use coordinates, only : nn => bcells
    use coordinates, only : nbl, neu
    use helpers    , only : print_message
    use operators  , only : curl
    use workspace  , only : resize_workspace, work, work_in_use

    implicit none

    type(block_data), pointer, intent(in) :: pdata
    integer                  , intent(in) :: ibeg, iend
    character(len=*)         , intent(in) :: loc

    real(kind=4) :: cmax

    logical, save :: first = .true.

    integer      :: i, j, k, nt, status
    real(kind=8) :: csq

    real(kind=8), dimension(3), save :: dh

! NOTE(review): this pointer is mapped onto the shared workspace only once
! per thread; if the workspace array is reallocated later (e.g. grown by
! another module through resize_workspace) it may dangle - confirm against
! the WORKSPACE module.
!
    real(kind=8), dimension(:,:,:,:), pointer, save :: w

!$  integer :: omp_get_thread_num
!$omp threadprivate(first,dh,w)

!-------------------------------------------------------------------------------
!
    cmax = 0.0e+00

    nt = 0
!$  nt = omp_get_thread_num()

    if (first) then
      i = 3 * nn**NDIMS

      call resize_workspace(i, status)
      if (status /= 0) then
        call print_message(loc, "Could not resize the workspace!")
        return
      end if

#if NDIMS == 3
      w(1:3,1:nn,1:nn,1:nn) => work(1:i,nt)
#else /* NDIMS == 3 */
      w(1:3,1:nn,1:nn,1: 1) => work(1:i,nt)
#endif /* NDIMS == 3 */

      dh(:) = 1.0d+00

      first = .false.
    end if

    if (work_in_use(nt)) &
      call print_message(loc, &
                    "Workspace is being used right now! Corruptions can occur!")

    work_in_use(nt) = .true.

    call curl(dh(:), pdata%q(ibeg:iend,:,:,:), w(:,:,:,:))

! track the squared magnitude in double precision; take the root only once
!
    csq = 0.0d+00

#if NDIMS == 2
    k = 1
#endif /* NDIMS == 2 */
#if NDIMS == 3
    do k = nbl, neu
#endif /* NDIMS == 3 */
      do j = nbl, neu
        do i = nbl, neu
          csq = max(csq, sum(w(:,i,j,k)**2))
        end do
      end do
#if NDIMS == 3
    end do
#endif /* NDIMS == 3 */

    work_in_use(nt) = .false.

    cmax = real(sqrt(csq), kind=4)

!-------------------------------------------------------------------------------
!
  end function curl_magnitude

!===============================================================================
!
end module refinement