!!******************************************************************************
!!
!!  This file is part of the AMUN source code, a program to perform
!!  Newtonian or relativistic magnetohydrodynamical simulations on uniform or
!!  adaptive mesh.
!!
!!  Copyright (C) 2008-2021 Grzegorz Kowal <grzegorz@amuncode.org>
!!
!!  This program is free software: you can redistribute it and/or modify
!!  it under the terms of the GNU General Public License as published by
!!  the Free Software Foundation, either version 3 of the License, or
!!  (at your option) any later version.
!!
!!  This program is distributed in the hope that it will be useful,
!!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!!  GNU General Public License for more details.
!!
!!  You should have received a copy of the GNU General Public License
!!  along with this program.  If not, see <http://www.gnu.org/licenses/>.
!!
!!******************************************************************************
!!
!! module: MPITOOLS
!!
!!  This module provides wrapper subroutines handling parallel execution
!!  with the Message Passing Interface (MPI) protocol.
!!
!!******************************************************************************
!
module mpitools

#ifdef MPI
  use mpi_f08
#endif /* MPI */
  use timers, only : set_timer, start_timer, stop_timer

  implicit none

! subroutine interfaces
!
#ifdef MPI
  interface reduce_minimum
    module procedure reduce_minimum_double_array
  end interface
  interface reduce_maximum
    module procedure reduce_maximum_integer
#ifndef __NVCOMPILER
    module procedure reduce_maximum_double
#endif /* __NVCOMPILER */
    module procedure reduce_maximum_double_array
  end interface
  interface reduce_sum
    module procedure reduce_sum_integer_array
    module procedure reduce_sum_double_array
    module procedure reduce_sum_complex_array
  end interface
  interface exchange_arrays
    module procedure exchange_arrays_diff
    module procedure exchange_arrays_same
  end interface

! timer indices
!
  integer, save :: imi, imc
#endif /* MPI */

! MPI global variables
!
  integer(kind=4), save :: nproc, nprocs, nodes, node, lprocs, lproc
  integer(kind=4), save :: npmax, lpmax, npairs
  logical        , save :: master = .true.

! allocatable array for processor pairs
!
  integer(kind=4), dimension(:,:), allocatable, save :: pairs

! by default everything is private
!
  private

! declare public subroutines
!
  public :: initialize_mpitools, finalize_mpitools
  public :: check_status
#ifdef MPI
  public :: reduce_minimum, reduce_maximum, reduce_sum
  public :: send_array, receive_array
  public :: exchange_arrays
#endif /* MPI */

! declare public variables
!
  public :: master, nproc, nprocs, nodes, node, lprocs, lproc
  public :: npmax, lpmax, npairs, pairs

!- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!
  contains
!
!===============================================================================
!!
!!***  PUBLIC SUBROUTINES  *****************************************************
!!
!===============================================================================
!
! subroutine INITIALIZE_MPITOOLS:
! ------------------------------
!
!   Subroutine initializes the MPITOOLS module.
!
!   Arguments:
!
!     status - the return value; if it is 0 everything went successfully,
!              otherwise there was a problem;
!
!===============================================================================
!
  subroutine initialize_mpitools(status)

    use helpers, only : print_message

    implicit none

    integer, intent(out) :: status

#ifdef MPI
    type(MPI_Comm) :: comm
    integer        :: mprocs, i, j, l, n
    integer        :: ierror

    integer(kind=4), dimension(:), allocatable :: procs

    character(len=*), parameter :: loc = 'MPITOOLS::initialize_mpitools()'
#endif /* MPI */

!-------------------------------------------------------------------------------
!
#ifdef MPI
    call set_timer('MPI initialization', imi)
    call set_timer('MPI communication' , imc)

    call start_timer(imi)
#endif /* MPI */

    status = 0

    nproc  = 0
    nprocs = 1
    npmax  = 0
    npairs = 0
    nodes  = 1
    node   = 0
    lproc  = 0
    lprocs = 1
    lpmax  = 0

#ifdef MPI
    call MPI_Init(ierror)

    if (ierror == MPI_SUCCESS) then

      call MPI_Comm_size(MPI_COMM_WORLD, nprocs, ierror)

      if (ierror == MPI_SUCCESS) then

        call MPI_Comm_rank(MPI_COMM_WORLD, nproc, ierror)

        if (ierror == MPI_SUCCESS) then

          master = nproc == 0

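! npmax is the highest process rank, mprocs the (even) number of slots used
! by the pair generator below, and npairs the number of distinct process
! pairs, nprocs * (nprocs - 1) / 2
!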
          npmax  = nprocs - 1
          mprocs = nprocs + mod(nprocs, 2)
          npairs = nprocs * npmax / 2

          allocate(procs(mprocs), pairs(2 * npairs, 2), stat = status)

          if (status == 0) then

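! generate the list of all process pairs using the circle (round-robin
! tournament) method: the rank in the first slot stays fixed while the
! remaining slots are rotated each round; for an odd number of processes
! a dummy rank is added and any pair containing it is skipped
!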
            procs(:) = (/(l, l = 0, mprocs - 1)/)

            n = 0

            do l = 1, mprocs - 1

              do i = 1, mprocs / 2

                j = mprocs - i + 1

                if (procs(i) < nprocs .and. procs(j) < nprocs) then

                  n = n + 1
                  pairs(n,1:2) = (/ procs(i), procs(j) /)

                end if ! max(procs(i), procs(j)) < nprocs

              end do ! i = 1, mprocs / 2

              procs(2:mprocs) = cshift(procs(2:mprocs), -1)

            end do ! l = 1, mprocs - 1

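! the second half of the list repeats the same pairs with the two processes
! swapped, so that every ordered pair appears exactly once
!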
            pairs(npairs+1:2*npairs,1:2) = pairs(1:npairs,2:1:-1)

            deallocate(procs, stat = status)

          end if ! allocate

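! split the global communicator into per-node (shared memory) communicators
! in order to determine the number of nodes, the index of the current node,
! and the rank of this process within its node
!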
          call MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0,    &
                                   MPI_INFO_NULL, comm, ierror)
          if (ierror == MPI_SUCCESS) then
            call MPI_Comm_size(comm, lprocs, ierror)
            if (ierror == MPI_SUCCESS) then
              lpmax = lprocs - 1
              nodes = nprocs / lprocs
              node  = nproc  / lprocs
            else
              call print_message(loc,                                          &
                                "Could not get the number of node processes!")
              status = 1
            end if
            call MPI_Comm_rank(comm, lproc, ierror)
            if (ierror /= MPI_SUCCESS) then
              call print_message(loc,                                          &
                                "Could not get the node rank!")
              status = 1
            end if
          else
            call print_message(loc, "Could not split the MPI communicator!")
            status = 1
          end if

        else
          call print_message(loc, "Could not get the MPI process ID!")
          status = 1
        end if
      else
        call print_message(loc, "Could not get the number of MPI processes!")
        status = 1
      end if
    else
      call print_message(loc, "Could not initialize the MPI interface!")
      status = 1
    end if

    call stop_timer(imi)
#endif /* MPI */

!-------------------------------------------------------------------------------
!
  end subroutine initialize_mpitools
!
!===============================================================================
!
! subroutine FINALIZE_MPITOOLS:
! ----------------------------
!
!   Subroutine finalizes the MPITOOLS module.
!
!   Arguments:
!
!     status - the return value; if it is 0 everything went successfully,
!              otherwise there was a problem;
!
!===============================================================================
!
  subroutine finalize_mpitools(status)

    use helpers, only : print_message

    implicit none

    integer, intent(out) :: status

#ifdef MPI
    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::finalize_mpitools()'
#endif /* MPI */

!-------------------------------------------------------------------------------
!
    status = 0

#ifdef MPI
    call start_timer(imi)

    if (allocated(pairs)) deallocate(pairs, stat = status)

    call MPI_Finalize(ierror)

    if (ierror /= MPI_SUCCESS) then
      call print_message(loc, "Could not finalize the MPI interface!")
      status = 1
    end if

    call stop_timer(imi)
#endif /* MPI */

!-------------------------------------------------------------------------------
!
  end subroutine finalize_mpitools
!
!===============================================================================
!
! function CHECK_STATUS:
! ---------------------
!
!   Function returns the logical OR of the input flag values from all MPI
!   processes if MPI is used; otherwise, it simply returns the input value.
!
!   Arguments:
!
!     flag - the input logical flag;
!
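!   A typical usage sketch (assuming a local integer status variable):
!
!     if (check_status(status /= 0)) return
!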
!===============================================================================
!
  logical function check_status(flag)

    use helpers, only : print_message

    implicit none

    logical, intent(in) :: flag
#ifdef MPI

    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::check_status()'
#endif /* MPI */
!
!-------------------------------------------------------------------------------
!
#ifdef MPI
    call start_timer(imc)
#endif /* MPI */

    check_status = flag

#ifdef MPI
    call MPI_Allreduce(MPI_IN_PLACE, check_status, 1,                          &
                       MPI_LOGICAL, MPI_LOR, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce of logical buffer failed!")

    call stop_timer(imc)
#endif /* MPI */

!-------------------------------------------------------------------------------
!
  end function check_status
#ifdef MPI
!
!===============================================================================
!
! subroutine SEND_ARRAY:
! ---------------------
!
!   Subroutine sends an array of real values to another process.
!
!   Arguments:
!
!     dst - the ID of the destination process;
!     tag - the tag identifying this operation;
!     buf - the buffer of real values to send;
!
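!   Note: a send_array() call must be matched by a receive_array() call with
!   the same tag on the destination process.
!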
!===============================================================================
!
  subroutine send_array(dst, tag, buf)

    use helpers, only : print_message

    implicit none

    integer                    , intent(in) :: dst, tag
    real(kind=8), dimension(..), intent(in) :: buf

    integer :: ierror

    character(len=80) :: msg

    character(len=*), parameter :: loc = 'MPITOOLS::send_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)

    call MPI_Send(buf, size(buf), MPI_REAL8, dst, tag, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) then
      write(msg,"('Could not send a real array from ',i0,' to ',i0)") nproc, dst
      call print_message(loc, msg)
    end if

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine send_array
!
!===============================================================================
!
! subroutine RECEIVE_ARRAY:
! ------------------------
!
!   Subroutine receives an array of real values from another process.
!
!   Arguments:
!
!     src - the ID of the source process;
!     tag - the tag identifying this operation;
!     buf - the received real array;
!
!===============================================================================
!
  subroutine receive_array(src, tag, buf)

    use helpers, only : print_message

    implicit none

    integer                    , intent(in)  :: src, tag
    real(kind=8), dimension(..), intent(out) :: buf

    integer :: ierror

    character(len=80) :: msg

    character(len=*), parameter :: loc = 'MPITOOLS::receive_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)

    call MPI_Recv(buf, size(buf), MPI_REAL8, src, tag,                         &
                                  MPI_COMM_WORLD, MPI_STATUS_IGNORE, ierror)

    if (ierror /= MPI_SUCCESS) then
      write(msg,"('Could not receive a real array from ',i0,' to ',i0)")       &
                                                                    src, nproc
      call print_message(loc, msg)
    end if

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine receive_array
!
!===============================================================================
!
! subroutine EXCHANGE_ARRAYS_DIFF:
! -------------------------------
!
!   Subroutine exchanges real data buffers (of different sizes) between
!   two processes.
!
!   Arguments:
!
!     proc  - the remote process rank to which the buffer sbuf is sent
!             and from which the buffer rbuf is received;
!     tag   - the tag identifying the send operation;
!     sbuf  - the real array buffer to send;
!     rbuf  - the real array buffer to receive;
!
!===============================================================================
!
  subroutine exchange_arrays_diff(proc, tag, sbuf, rbuf)

    use helpers, only : print_message

    implicit none

    integer                    , intent(in)  :: proc, tag
    real(kind=8), dimension(..), intent(in)  :: sbuf
    real(kind=8), dimension(..), intent(out) :: rbuf

    integer :: ssize, rsize
    integer :: ierror

    character(len=80) :: msg

    character(len=*), parameter :: loc = 'MPITOOLS::exchange_arrays_diff()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)

    ssize = size(sbuf)
    rsize = size(rbuf)

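! when both buffers are non-empty, exchange them with a single MPI_Sendrecv;
! if one of them is empty, issue only the corresponding send or receive
! (it is assumed that at least one of the two buffers is non-empty)
!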
    if (ssize > 0 .and. rsize > 0) then
      call MPI_Sendrecv(sbuf, ssize, MPI_REAL8, proc, tag,                     &
                        rbuf, rsize, MPI_REAL8, proc, tag,                     &
                        MPI_COMM_WORLD, MPI_STATUS_IGNORE, ierror)
    else
      if (ssize > 0) then
        call MPI_Send(sbuf, ssize, MPI_REAL8, proc, tag, MPI_COMM_WORLD,       &
                      ierror)
      else
        call MPI_Recv(rbuf, rsize, MPI_REAL8, proc, tag, MPI_COMM_WORLD,       &
                      MPI_STATUS_IGNORE, ierror)
      end if
    end if

    if (ierror /= MPI_SUCCESS) then
      write(msg,"('Could not exchange real data " //                           &
                               "buffers between ',i0,' and ',i0)") proc, nproc
      call print_message(loc, msg)
    end if

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine exchange_arrays_diff
!
!===============================================================================
!
! subroutine EXCHANGE_ARRAYS_SAME:
! -------------------------------
!
!   Subroutine exchanges a real data buffer between two processes.
!
!   Arguments:
!
!     proc  - the remote process rank with which the buffer buf is exchanged;
!     tag   - the tag identifying the exchange operation;
!     buf   - the real array buffer to send and receive (replaced in place);
!
!===============================================================================
!
  subroutine exchange_arrays_same(proc, tag, buf)

    use helpers, only : print_message

    implicit none

    integer                    , intent(in)    :: proc, tag
    real(kind=8), dimension(..), intent(inout) :: buf

    integer :: ierror

    character(len=80) :: msg

    character(len=*), parameter :: loc = 'MPITOOLS::exchange_arrays_same()'

!-------------------------------------------------------------------------------
!
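! exchange the buffer in place: the local contents are sent to the remote
! process and replaced with the data received from it
!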
    call start_timer(imc)
    call MPI_Sendrecv_replace(buf, size(buf), MPI_REAL8, proc, tag, proc, tag, &
                              MPI_COMM_WORLD, MPI_STATUS_IGNORE, ierror)

    if (ierror /= MPI_SUCCESS) then
      write(msg,"('Could not exchange real data " //                           &
                               "buffer between ',i0,' and ',i0)") proc, nproc
      call print_message(loc, msg)
    end if

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine exchange_arrays_same
!
!===============================================================================
!!
!!***  PRIVATE SUBROUTINES  ****************************************************
!!
!===============================================================================
!
!===============================================================================
!
! subroutine REDUCE_MINIMUM_DOUBLE_ARRAY:
! --------------------------------------
!
!   Subroutine finds the minimum value for each double precision array element
!   among the corresponding values from all processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_minimum_double_array(buf)

    use helpers, only : print_message

    implicit none

    real(kind=8), dimension(:), intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: &
                                loc = 'MPITOOLS::reduce_minimum_double_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)
    call MPI_Allreduce(MPI_IN_PLACE, buf, size(buf),                           &
                       MPI_REAL8, MPI_MIN, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce of a real array failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_minimum_double_array
!
!===============================================================================
!
! subroutine REDUCE_MAXIMUM_INTEGER:
! ---------------------------------
!
!   Subroutine finds the maximum value among the integer values from all
!   processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_maximum_integer(buf)

    use helpers, only : print_message

    implicit none

    integer, intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::reduce_maximum_integer()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)

    call MPI_Allreduce(MPI_IN_PLACE, buf, 1,                                   &
                       MPI_INTEGER, MPI_MAX, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce of an integer value failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_maximum_integer
#ifndef __NVCOMPILER
!
!===============================================================================
!
! subroutine REDUCE_MAXIMUM_DOUBLE:
! --------------------------------
!
!   Subroutine finds the maximum value among the double precision values
!   from all processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_maximum_double(buf)

    use helpers, only : print_message

    implicit none

    real(kind=8), intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::reduce_maximum_double()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)
    call MPI_Allreduce(MPI_IN_PLACE, buf, 1,                                   &
                       MPI_REAL8, MPI_MAX, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce of a real value failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_maximum_double
#endif /* __NVCOMPILER */
!
!===============================================================================
!
! subroutine REDUCE_MAXIMUM_DOUBLE_ARRAY:
! --------------------------------------
!
!   Subroutine finds the maximum value for each double precision array element
!   among the corresponding values from all processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_maximum_double_array(buf)

    use helpers, only : print_message

    implicit none

    real(kind=8), dimension(:), intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: &
                               loc = 'MPITOOLS::reduce_maximum_double_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)
    call MPI_Allreduce(MPI_IN_PLACE, buf, size(buf),                           &
                       MPI_REAL8, MPI_MAX, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce of a real array failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_maximum_double_array
!
!===============================================================================
!
! subroutine REDUCE_SUM_INTEGER_ARRAY:
! -----------------------------------
!
!   Subroutine sums each integer array element over the corresponding
!   values from all processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_sum_integer_array(buf)

    use helpers, only : print_message

    implicit none

    integer, dimension(:), intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::reduce_sum_integer_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)
    call MPI_Allreduce(MPI_IN_PLACE, buf, size(buf),                           &
                       MPI_INTEGER, MPI_SUM, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce of an integer real array failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_sum_integer_array
!
!===============================================================================
!
! subroutine REDUCE_SUM_DOUBLE_ARRAY:
! ----------------------------------
!
!   Subroutine sums each double precision array element over the
!   corresponding values from all processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_sum_double_array(buf)

    use helpers, only : print_message

    implicit none

    real(kind=8), dimension(:), intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::reduce_sum_double_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)
    call MPI_Allreduce(MPI_IN_PLACE, buf, size(buf),                           &
                       MPI_REAL8, MPI_SUM, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce a real array failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_sum_double_array
!
!===============================================================================
!
! subroutine REDUCE_SUM_COMPLEX_ARRAY:
! -----------------------------------
!
!   Subroutine sums each complex array element over the corresponding
!   values from all processes.
!
!   Argument:
!
!     buf - a buffer to be reduced;
!
!===============================================================================
!
  subroutine reduce_sum_complex_array(buf)

    use helpers, only : print_message

    implicit none

    complex(kind=8), dimension(:,:), intent(inout) :: buf

    integer :: ierror

    character(len=*), parameter :: loc = 'MPITOOLS::reduce_sum_complex_array()'

!-------------------------------------------------------------------------------
!
    call start_timer(imc)
    call MPI_Allreduce(MPI_IN_PLACE, buf, size(buf),                           &
                       MPI_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD, ierror)

    if (ierror /= MPI_SUCCESS) &
      call print_message(loc, "MPI_Allreduce a complex array failed!")

    call stop_timer(imc)

!-------------------------------------------------------------------------------
!
  end subroutine reduce_sum_complex_array
#endif /* MPI */

!===============================================================================
!
end module mpitools