!!******************************************************************************
!!
!!  This file is part of the AMUN source code, a program to perform
!!  Newtonian or relativistic magnetohydrodynamical simulations on uniform or
!!  adaptive mesh.
!!
!!  Copyright (C) 2008-2020 Grzegorz Kowal <grzegorz@amuncode.org>
!!
!!  This program is free software: you can redistribute it and/or modify
!!  it under the terms of the GNU General Public License as published by
!!  the Free Software Foundation, either version 3 of the License, or
!!  (at your option) any later version.
!!
!!  This program is distributed in the hope that it will be useful,
!!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!!  GNU General Public License for more details.
!!
!!  You should have received a copy of the GNU General Public License
!!  along with this program.  If not, see <http://www.gnu.org/licenses/>.
!!
!!******************************************************************************
!!
!! module: SOURCES
!!
!!  This modules adds source terms.
!!
!!******************************************************************************
!
module sources

#ifdef PROFILE
! timer procedures, used only when the code is built with profiling
!
  use timers, only : set_timer, start_timer, stop_timer
#endif /* PROFILE */

! all module entities must be declared explicitly
!
  implicit none

! module entities are hidden unless listed as public below
!
  private

#ifdef PROFILE
! timer indices: imi - initialization/finalization, imu - source term update
!
  integer, save :: imi, imu
#endif /* PROFILE */

! the GLM-MHD source term selector (0 - none, 1 - EGLM, 2 - HEGLM) and the
! name printed for it
!
  integer          , save :: glm_type    = 0
  character(len=32), save :: glm_name    = "none"

! viscosity coefficient (multiplied by density when forming the stress tensor)
!
  real(kind=8)     , save :: viscosity   = 0.0d+00

! resistivity coefficients: uniform resistivity, anomalous resistivity
! coefficient, and the critical current density normalizing |J|
!
  real(kind=8)     , save :: resistivity = 0.0d+00
  real(kind=8)     , save :: anomalous   = 0.0d+00
  real(kind=8)     , save :: jcrit       = 1.0d+00

! public interface of the module: its procedures and the two coefficients
!
  public :: initialize_sources, finalize_sources, print_sources
  public :: update_sources
  public :: viscosity, resistivity

!- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!
  contains
!
!===============================================================================
!!
!!***  PUBLIC SUBROUTINES  *****************************************************
!!
!===============================================================================
!
!===============================================================================
!
! subroutine INITIALIZE_SOURCES:
! -----------------------------
!
!   Subroutine initializes module SOURCES. It reads the GLM source term type
!   and the viscosity and resistivity coefficients through get_parameter,
!   and checks that the coefficients have admissible values.
!
!   Arguments:
!
!     verbose - a logical flag turning the information printing;
!     status  - an integer flag for error return value (0 on success, 1 when
!               any coefficient has an inadmissible value);
!
!===============================================================================
!
  subroutine initialize_sources(verbose, status)

! include external procedures and variables
!
    use parameters , only : get_parameter

! local variables are not implicit by default
!
    implicit none

! subroutine arguments
!
    logical, intent(in)  :: verbose
    integer, intent(out) :: status

! local variables; tglm is deliberately NOT initialized in its declaration,
! since that would give it the SAVE attribute
!
    character(len=8) :: tglm
!
!-------------------------------------------------------------------------------
!
#ifdef PROFILE
! set timer descriptions
!
    call set_timer('sources:: initialize', imi)
    call set_timer('sources:: update'    , imu)

! start accounting time for module initialization/finalization
!
    call start_timer(imi)
#endif /* PROFILE */

! reset the status flag
!
    status = 0

! get the type of the GLM source terms ("none" unless set in parameters)
!
    tglm = "none"
    call get_parameter("glm_source_terms", tglm)

! set the glm_type variable to correct value; unknown names disable GLM terms
!
    select case(trim(tglm))
    case("eglm", "EGLM")
      glm_type = 1
      glm_name = "EGLM"
    case("heglm", "HEGLM")
      glm_type = 2
      glm_name = "HEGLM"
    case default
      glm_type = 0
      glm_name = "none"
    end select

! get viscosity coefficient, which must be non-negative
!
    call get_parameter("viscosity", viscosity)
    if (viscosity < 0.0d+00) &
      call report_error("Negative viscosity coefficient!")

! get resistivity coefficients, which must be non-negative
!
    call get_parameter("resistivity", resistivity)
    if (resistivity < 0.0d+00) &
      call report_error("Negative resistivity coefficient!")

    call get_parameter("anomalous", anomalous)
    if (anomalous < 0.0d+00) &
      call report_error("Negative anomalous resistivity coefficient!")

! the critical current density normalizes |J|, so it must be positive
!
    call get_parameter("jcrit", jcrit)
    if (jcrit <= 0.0d+00) &
      call report_error("Non-positive critical current density coefficient!")

#ifdef PROFILE
! stop accounting time for module initialization/finalization
!
    call stop_timer(imi)
#endif /* PROFILE */

!-------------------------------------------------------------------------------
!
  contains

! report an invalid parameter: print the message (if verbose) and flag the
! error in the host's status argument
!
    subroutine report_error(message)

      implicit none

      character(len=*), intent(in) :: message

      if (verbose) then
        write(*,*)
        write(*,"(1x,a)") "ERROR!"
        write(*,"(1x,a)") message
      end if
      status = 1

    end subroutine report_error

  end subroutine initialize_sources
!
!===============================================================================
!
! subroutine FINALIZE_SOURCES:
! ---------------------------
!
!   Subroutine finalizes module SOURCES. The module owns no allocated
!   memory, so finalization only resets the status flag; with profiling
!   enabled, the time is accounted under the initialization timer.
!
!   Arguments:
!
!     status - an integer flag for error return value (always 0);
!
!===============================================================================
!
  subroutine finalize_sources(status)

! all local entities must be declared explicitly
!
    implicit none

! returned error flag
!
    integer, intent(out) :: status
!
!-------------------------------------------------------------------------------
!
#ifdef PROFILE
! finalization time goes to the same timer as initialization
!
    call start_timer(imi)
#endif /* PROFILE */

! finalization has nothing that can fail
!
    status = 0

#ifdef PROFILE
! close the initialization/finalization timing interval
!
    call stop_timer(imi)
#endif /* PROFILE */

!-------------------------------------------------------------------------------
!
  end subroutine finalize_sources
!
!===============================================================================
!
! subroutine PRINT_SOURCES:
! ------------------------
!
!   Subroutine prints the module parameters. The viscosity is always
!   printed. The resistivity and the GLM source term type are printed only
!   for magnetized equations, and the anomalous resistivity parameters only
!   when the anomalous coefficient is non-zero.
!
!   Arguments:
!
!     verbose - a logical flag turning the information printing;
!
!===============================================================================
!
  subroutine print_sources(verbose)

! include external procedures
!
    use equations, only : magnetized
    use helpers  , only : print_section, print_parameter

! local variables are not implicit by default
!
    implicit none

! subroutine arguments
!
    logical, intent(in) :: verbose
!
!-------------------------------------------------------------------------------
!
! nothing is printed unless requested
!
    if (.not. verbose) return

    call print_section(verbose, "Source terms")
    call print_parameter(verbose, "viscosity"       , viscosity  )

! the remaining parameters concern the magnetic field only
!
    if (.not. magnetized) return

    call print_parameter(verbose, "resistivity"     , resistivity)

    if (anomalous /= 0.0d+00) then
      call print_parameter(verbose, "anomalous"       , anomalous  )
      call print_parameter(verbose, "jcrit"           , jcrit      )
    end if

    call print_parameter(verbose, "glm source terms", glm_name   )

!-------------------------------------------------------------------------------
!
  end subroutine print_sources
!
!===============================================================================
!
! subroutine UPDATE_SOURCES:
! -------------------------
!
!   Subroutine adds the source terms to the increment array du, in order:
!   gravity (if enabled), viscosity (if the coefficient is positive),
!   magnetic terms (GLM-MHD, anomalous or uniform resistivity; only when
!   ibx > 0), and finally the user-defined source terms.
!
!   Arguments:
!
!     pdata - the pointer to a data block;
!     t, dt - the time and time increment;
!     du    - the array of variable increment;
!
!===============================================================================
!
  subroutine update_sources(pdata, t, dt, du)

! include external variables
!
    use blocks         , only : block_data
    use coordinates    , only : nn => bcells
    use coordinates    , only : ax, ay, az, adx, ady, adz
    use equations      , only : nv, inx, iny, inz
    use equations      , only : idn, ivx, ivy, ivz, imx, imy, imz, ien
    use equations      , only : ibx, iby, ibz, ibp
    use gravity        , only : gravity_enabled, gravitational_acceleration
    use operators      , only : divergence, gradient, laplace, curl
    use user_problem   , only : update_sources_user

! local variables are not implicit by default
!
    implicit none

! subroutine arguments
!
    type(block_data), pointer       , intent(inout) :: pdata
    real(kind=8)                    , intent(in)    :: t, dt
    real(kind=8), dimension(:,:,:,:), intent(inout) :: du

! local variables; k is not initialized in its declaration, because that would
! give it the SAVE attribute and make this subroutine non-reentrant
!
    integer       :: i, j, k
    real(kind=8)  :: fc, gc
    real(kind=8)  :: gx, gy, gz
    real(kind=8)  :: dbx, dby, dbz
    real(kind=8)  :: dvxdx, dvxdy, dvxdz, divv
    real(kind=8)  :: dvydx, dvydy, dvydz
    real(kind=8)  :: dvzdx, dvzdy, dvzdz

! local arrays
!
    real(kind=8), dimension(3)  :: ga, dh
    real(kind=8), dimension(nn) :: x, y
#if NDIMS == 3
    real(kind=8), dimension(nn) :: z
    real(kind=8), dimension(nn,nn,nn)     :: db
    real(kind=8), dimension(3,3,nn,nn,nn) :: tmp
#else /* NDIMS == 3 */
    real(kind=8), dimension( 1) :: z
    real(kind=8), dimension(nn,nn, 1)     :: db
    real(kind=8), dimension(3,3,nn,nn, 1) :: tmp
#endif /* NDIMS == 3 */
!
!-------------------------------------------------------------------------------
!
#ifdef PROFILE
! start accounting time for source terms
!
    call start_timer(imu)
#endif /* PROFILE */

! in 2D the k loops are compiled out and only the k = 1 layer exists
!
    k = 1

! proceed only if the gravitational term is enabled
!
    if (gravity_enabled) then

! prepare block coordinates; in 2D the single z coordinate must still be
! defined, since it is passed to gravitational_acceleration
!
      x(:) = pdata%meta%xmin + ax(pdata%meta%level,:)
      y(:) = pdata%meta%ymin + ay(pdata%meta%level,:)
#if NDIMS == 3
      z(:) = pdata%meta%zmin + az(pdata%meta%level,:)
#else /* NDIMS == 3 */
      z(:) = 0.0d+00
#endif /* NDIMS == 3 */

! iterate over all cells of the block
!
#if NDIMS == 3
      do k = 1, nn
#endif /* NDIMS == 3 */
        do j = 1, nn
          do i = 1, nn

! get gravitational acceleration components
!
            call gravitational_acceleration(t, dt, x(i), y(j), z(k), ga(1:3))

! calculate the gravitational source terms (ρ g)
!
            gx = pdata%q(idn,i,j,k) * ga(1)
            gy = pdata%q(idn,i,j,k) * ga(2)
#if NDIMS == 3
            gz = pdata%q(idn,i,j,k) * ga(3)
#endif /* NDIMS == 3 */

! add source terms to momentum equations
!
            du(imx,i,j,k) = du(imx,i,j,k) + gx
            du(imy,i,j,k) = du(imy,i,j,k) + gy
#if NDIMS == 3
            du(imz,i,j,k) = du(imz,i,j,k) + gz
#endif /* NDIMS == 3 */

! add source terms to total energy equation (ρ v.g)
!
            if (ien > 0) then

#if NDIMS == 2
              du(ien,i,j,k) = du(ien,i,j,k) + gx * pdata%q(ivx,i,j,k)          &
                                            + gy * pdata%q(ivy,i,j,k)
#endif /* NDIMS == 2 */
#if NDIMS == 3
              du(ien,i,j,k) = du(ien,i,j,k) + gx * pdata%q(ivx,i,j,k)          &
                                            + gy * pdata%q(ivy,i,j,k)          &
                                            + gz * pdata%q(ivz,i,j,k)
#endif /* NDIMS == 3 */

            end if

          end do ! i = 1, nn
        end do ! j = 1, nn
#if NDIMS == 3
      end do ! k = 1, nn
#endif /* NDIMS == 3 */

    end if ! gravity enabled

! proceed only if the viscosity coefficient is not zero
!
    if (viscosity > 0.0d+00) then

! prepare coordinate increments
!
      dh(1) = adx(pdata%meta%level)
      dh(2) = ady(pdata%meta%level)
      dh(3) = adz(pdata%meta%level)

! calculate the velocity Jacobian
!
      call gradient(dh(:), pdata%q(ivx,:,:,:), tmp(inx,inx:inz,:,:,:))
      call gradient(dh(:), pdata%q(ivy,:,:,:), tmp(iny,inx:inz,:,:,:))
      call gradient(dh(:), pdata%q(ivz,:,:,:), tmp(inz,inx:inz,:,:,:))

! iterate over all cells
!
#if NDIMS == 3
      do k = 1, nn
#endif /* NDIMS == 3 */
        do j = 1, nn
          do i = 1, nn

! prepare the νρ factor
!
            gc    = viscosity * pdata%q(idn,i,j,k)
            fc    = 2.0d+00 * gc

! get the velocity Jacobian elements
!
            dvxdx = tmp(inx,inx,i,j,k)
            dvxdy = tmp(inx,iny,i,j,k)
            dvxdz = tmp(inx,inz,i,j,k)
            dvydx = tmp(iny,inx,i,j,k)
            dvydy = tmp(iny,iny,i,j,k)
            dvydz = tmp(iny,inz,i,j,k)
            dvzdx = tmp(inz,inx,i,j,k)
            dvzdy = tmp(inz,iny,i,j,k)
            dvzdz = tmp(inz,inz,i,j,k)
            divv  = (dvxdx + dvydy + dvzdz) / 3.0d+00

! calculate elements of the viscous stress tensor
! τ = νρ [∇v + (∇v)ᵀ - 2/3 (∇.v) I], overwriting the Jacobian in place
!
            tmp(inx,inx,i,j,k) = fc * (dvxdx - divv)
            tmp(iny,iny,i,j,k) = fc * (dvydy - divv)
            tmp(inz,inz,i,j,k) = fc * (dvzdz - divv)
            tmp(inx,iny,i,j,k) = gc * (dvxdy + dvydx)
            tmp(inx,inz,i,j,k) = gc * (dvxdz + dvzdx)
            tmp(iny,inz,i,j,k) = gc * (dvydz + dvzdy)
            tmp(iny,inx,i,j,k) = tmp(inx,iny,i,j,k)
            tmp(inz,inx,i,j,k) = tmp(inx,inz,i,j,k)
            tmp(inz,iny,i,j,k) = tmp(iny,inz,i,j,k)

          end do ! i = 1, nn
        end do ! j = 1, nn
#if NDIMS == 3
      end do ! k = 1, nn
#endif /* NDIMS == 3 */

! calculate the divergence of the first tensor row
!
      call divergence(dh(:), tmp(inx,inx:inz,:,:,:), db(:,:,:))

! add viscous source terms to the X momentum equation
!
      du(imx,:,:,:) = du(imx,:,:,:) + db(:,:,:)

! calculate the divergence of the second tensor row
!
      call divergence(dh(:), tmp(iny,inx:inz,:,:,:), db(:,:,:))

! add viscous source terms to the Y momentum equation
!
      du(imy,:,:,:) = du(imy,:,:,:) + db(:,:,:)

! calculate the divergence of the third tensor row
!
      call divergence(dh(:), tmp(inz,inx:inz,:,:,:), db(:,:,:))

! add viscous source terms to the Z momentum equation
!
      du(imz,:,:,:) = du(imz,:,:,:) + db(:,:,:)

! add viscous source term to total energy equation
!
      if (ien > 0) then

! iterate over all cells
!
#if NDIMS == 3
        do k = 1, nn
#endif /* NDIMS == 3 */
          do j = 1, nn
            do i = 1, nn

! calculate scalar product of v and viscous stress tensor τ
!
              gx = pdata%q(ivx,i,j,k) * tmp(inx,inx,i,j,k)                     &
                 + pdata%q(ivy,i,j,k) * tmp(inx,iny,i,j,k)                     &
                 + pdata%q(ivz,i,j,k) * tmp(inx,inz,i,j,k)
              gy = pdata%q(ivx,i,j,k) * tmp(iny,inx,i,j,k)                     &
                 + pdata%q(ivy,i,j,k) * tmp(iny,iny,i,j,k)                     &
                 + pdata%q(ivz,i,j,k) * tmp(iny,inz,i,j,k)
              gz = pdata%q(ivx,i,j,k) * tmp(inz,inx,i,j,k)                     &
                 + pdata%q(ivy,i,j,k) * tmp(inz,iny,i,j,k)                     &
                 + pdata%q(ivz,i,j,k) * tmp(inz,inz,i,j,k)

! store (v.τ) in the first row of the tensor tmp; safe, because all elements
! of this cell have already been read above
!
              tmp(inx,inx,i,j,k) = gx
              tmp(inx,iny,i,j,k) = gy
              tmp(inx,inz,i,j,k) = gz

            end do ! i = 1, nn
          end do ! j = 1, nn
#if NDIMS == 3
        end do ! k = 1, nn
#endif /* NDIMS == 3 */

! calculate the divergence of (v.τ)
!
        call divergence(dh(:), tmp(inx,inx:inz,:,:,:), db(:,:,:))

! update the energy increment
!
        du(ien,:,:,:) = du(ien,:,:,:) + db(:,:,:)

      end if ! ien > 0

    end if ! viscosity is not zero

!=== add magnetic field related source terms ===
!
    if (ibx > 0) then

! prepare coordinate increments
!
      dh(1) = adx(pdata%meta%level)
      dh(2) = ady(pdata%meta%level)
      dh(3) = adz(pdata%meta%level)

! add the EGLM-MHD source terms
!
      if (glm_type > 0) then

! calculate the magnetic field divergence
!
        call divergence(dh(:), pdata%q(ibx:ibz,:,:,:), db(:,:,:))

! update the momentum component increments, i.e.
!     d/dt (ρv) + ∇.F = - (∇.B)B
!
        du(imx,:,:,:) = du(imx,:,:,:) - db(:,:,:) * pdata%q(ibx,:,:,:)
        du(imy,:,:,:) = du(imy,:,:,:) - db(:,:,:) * pdata%q(iby,:,:,:)
        du(imz,:,:,:) = du(imz,:,:,:) - db(:,:,:) * pdata%q(ibz,:,:,:)

! update the energy equation
!
        if (ien > 0 .and. ibp > 0) then

! calculate the gradient of divergence potential
!
          call gradient(dh(:), pdata%q(ibp,:,:,:), tmp(inx:inz,inx,:,:,:))

! add the divergence potential source term to the energy equation, i.e.
!     d/dt E + ∇.F = - B.(∇ψ)
!
          du(ien,:,:,:) = du(ien,:,:,:)                                        &
                     - sum(pdata%q(ibx:ibz,:,:,:) * tmp(inx:inz,inx,:,:,:), 1)

        end if ! ien > 0

! add the HEGLM-MHD source terms
!
        if (glm_type > 1) then

! update magnetic field component increments, i.e.
!     d/dt B + ∇.F = - (∇.B)v
!
          du(ibx,:,:,:) = du(ibx,:,:,:) - db(:,:,:) * pdata%q(ivx,:,:,:)
          du(iby,:,:,:) = du(iby,:,:,:) - db(:,:,:) * pdata%q(ivy,:,:,:)
          du(ibz,:,:,:) = du(ibz,:,:,:) - db(:,:,:) * pdata%q(ivz,:,:,:)

! update the energy equation
!
          if (ien > 0) then

! calculate scalar product of velocity and magnetic field
!
            tmp(inx,inx,:,:,:) = sum(pdata%q(ivx:ivz,:,:,:)                    &
                                                  * pdata%q(ibx:ibz,:,:,:), 1)

! add the divergence potential source term to the energy equation, i.e.
!     d/dt E + ∇.F = - (∇.B) (v.B)
!
            du(ien,:,:,:) = du(ien,:,:,:) - db(:,:,:) * tmp(inx,inx,:,:,:)

          end if ! ien > 0

        end if ! glm_type > 1

      end if ! glm_type > 0

! if anomalous resistivity is enabled
!
      if (anomalous > 0.0d+00) then

! calculate current density (J = ∇xB)
!
        call curl(dh(:), pdata%q(ibx:ibz,:,:,:), tmp(inx:inz,inx,:,:,:))

! calculate the normalized absolute value of current density (|J|/Jcrit)
!
        tmp(inx,iny,:,:,:) = sqrt(sum(tmp(inx:inz,inx,:,:,:)**2, 1)) / jcrit

! calculate the local resistivity [ηu + ηa (|J|/Jcrit - 1) H(|J|/Jcrit - 1)]
!
        tmp(iny,iny,:,:,:) = resistivity +                                     &
                      anomalous * max(0.0d+00, (tmp(inx,iny,:,:,:) - 1.0d+00))

! multiply the current density vector by the local resistivity (ηJ)
!
        tmp(inx,inz,:,:,:) = tmp(iny,iny,:,:,:) * tmp(inx,inx,:,:,:)
        tmp(iny,inz,:,:,:) = tmp(iny,iny,:,:,:) * tmp(iny,inx,:,:,:)
        tmp(inz,inz,:,:,:) = tmp(iny,iny,:,:,:) * tmp(inz,inx,:,:,:)

! calculate the curl of (ηJ); the |J| and η scratch slots are no longer needed
!
        call curl(dh(:), tmp(inx:inz,inz,:,:,:), tmp(inx:inz,iny,:,:,:))

! update magnetic field component increments, i.e.
!     d/dt B + ∇.F = - ∇x(ηJ)
!
        du(ibx,:,:,:) = du(ibx,:,:,:) - tmp(inx,iny,:,:,:)
        du(iby,:,:,:) = du(iby,:,:,:) - tmp(iny,iny,:,:,:)
        du(ibz,:,:,:) = du(ibz,:,:,:) - tmp(inz,iny,:,:,:)

! update energy equation
!
        if (ien > 0) then

! calculate the vector product Bx(η ∇xB)
!
          tmp(inx,iny,:,:,:) = pdata%q(iby,:,:,:) * tmp(inz,inz,:,:,:)         &
                             - pdata%q(ibz,:,:,:) * tmp(iny,inz,:,:,:)
          tmp(iny,iny,:,:,:) = pdata%q(ibz,:,:,:) * tmp(inx,inz,:,:,:)         &
                             - pdata%q(ibx,:,:,:) * tmp(inz,inz,:,:,:)
          tmp(inz,iny,:,:,:) = pdata%q(ibx,:,:,:) * tmp(iny,inz,:,:,:)         &
                             - pdata%q(iby,:,:,:) * tmp(inx,inz,:,:,:)

! calculate the divergence ∇.[Bx(η ∇xB)]
!
          call divergence(dh(:), tmp(inx:inz,iny,:,:,:), db(:,:,:))

! add the resistive source term to the energy equation, i.e.
!     d/dt E + ∇.F = ∇.[Bx(η J)]
!
          du(ien,:,:,:) = du(ien,:,:,:) + db(:,:,:)

        end if ! energy equation present

      else if (resistivity > 0.0d+00) then

! calculate the Laplace operator of B, i.e. Δ(B)
!
        call laplace(dh(:), pdata%q(ibx,:,:,:), tmp(inx,inx,:,:,:))
        call laplace(dh(:), pdata%q(iby,:,:,:), tmp(inx,iny,:,:,:))
        call laplace(dh(:), pdata%q(ibz,:,:,:), tmp(inx,inz,:,:,:))

! multiply by the resistivity coefficient
!
        tmp(iny,inx:inz,:,:,:) = resistivity * tmp(inx,inx:inz,:,:,:)

! update magnetic field component increments
!
        du(ibx,:,:,:) = du(ibx,:,:,:) + tmp(iny,inx,:,:,:)
        du(iby,:,:,:) = du(iby,:,:,:) + tmp(iny,iny,:,:,:)
        du(ibz,:,:,:) = du(ibz,:,:,:) + tmp(iny,inz,:,:,:)

! update energy equation
!
        if (ien > 0) then

! add the first resistive source term to the energy equation, i.e.
!     d/dt E + ∇.F = η B.[Δ(B)]
!
          du(ien,:,:,:) = du(ien,:,:,:)                                        &
                        + (pdata%q(ibx,:,:,:) * tmp(iny,inx,:,:,:)             &
                        +  pdata%q(iby,:,:,:) * tmp(iny,iny,:,:,:)             &
                        +  pdata%q(ibz,:,:,:) * tmp(iny,inz,:,:,:))

! calculate current density J = ∇xB
!
          call curl(dh(:), pdata%q(ibx:ibz,:,:,:), tmp(inz,inx:inz,:,:,:))

! calculate J²
!
          db(:,:,:) = tmp(inz,inx,:,:,:)**2 + tmp(inz,iny,:,:,:)**2            &
                                            + tmp(inz,inz,:,:,:)**2

! add the second resistive source term to the energy equation, i.e.
!     d/dt E + ∇.F = η J²
!
          du(ien,:,:,:) = du(ien,:,:,:) + resistivity * db(:,:,:)

        end if ! energy equation present

      end if ! resistivity is not zero

    end if ! ibx > 0

! add user defined source terms
!
    call update_sources_user(pdata, t, dt, du(:,:,:,:))

#ifdef PROFILE
! stop accounting time for source terms
!
    call stop_timer(imu)
#endif /* PROFILE */

!-------------------------------------------------------------------------------
!
  end subroutine update_sources

!===============================================================================
!
end module sources