!!******************************************************************************
!!
!!  This file is part of the AMUN source code, a program to perform
!!  Newtonian or relativistic magnetohydrodynamical simulations on uniform or
!!  adaptive mesh.
!!
!!  Copyright (C) 2012-2020 Yann Collet
!!  Copyright (C) 2020-2022 Grzegorz Kowal <grzegorz@amuncode.org>
!!
!!  This program is free software: you can redistribute it and/or modify
!!  it under the terms of the GNU General Public License as published by
!!  the Free Software Foundation, either version 3 of the License, or
!!  (at your option) any later version.
!!
!!  This program is distributed in the hope that it will be useful,
!!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!!  GNU General Public License for more details.
!!
!!  You should have received a copy of the GNU General Public License
!!  along with this program.  If not, see <http://www.gnu.org/licenses/>.
!!
!!******************************************************************************
!!
!! module: HASH
!!
!!  This module is an interface to the xxHash functions by Yann Collet provided
!!  by the library libxxhash (XXH64 and XXH3). If this library is not
!!  available, an internal Fortran implementation of the 64-bit xxHash
!!  algorithm (XXH64) is used instead, and XXH3 is not available.
!!  The Fortran implementation is based on the XXH64 specification published at
!!      https://github.com/Cyan4973/xxHash/blob/dev/doc/xxhash_spec.md
!!
!!  For additional info, see
!!      http://www.xxhash.com or https://github.com/Cyan4973/xxHash
!!
!!******************************************************************************
!
module hash

  implicit none

#ifdef XXHASH
! interfaces to functions XXH64() and XXH3_64bits() provided by
! the library libxxhash
!
! NOTE: in C both functions return an unsigned 64-bit value (XXH64_hash_t);
!       here it is received as a signed c_int64_t, so the bit pattern is
!       preserved but digests with the top bit set appear negative
!
  interface
    integer(c_int64_t) function xxh64_lib(input, length, seed) &
                                                    bind(C, name="XXH64")
      use iso_c_binding, only: c_ptr, c_size_t, c_int64_t
      implicit none
      type(c_ptr)           , value :: input
      integer(kind=c_size_t), value :: length
      integer(c_int64_t)    , value :: seed
    end function xxh64_lib

    integer(c_int64_t) function xxh3_lib(input, length) &
                                                    bind(C, name="XXH3_64bits")
      use iso_c_binding, only: c_ptr, c_size_t, c_int64_t
      implicit none
      type(c_ptr)           , value :: input
      integer(kind=c_size_t), value :: length
    end function xxh3_lib
  end interface
#else /* XXHASH */
! hash parameters
!
! The XXH64 primes, stored as signed 64-bit integers with the same bit
! pattern as the unsigned constants of the specification:
!
!   prime1 = 0x9E3779B185EBCA87    prime4 = 0x85EBCA77C2B2AE63
!   prime2 = 0xC2B2AE3D27D4EB4F    prime5 = 0x27D4EB2F165667C5
!   prime3 = 0x165667B19E3779F9
!
! prime6 = prime1 + prime2 (mod 2^64) is precomputed because it is the
! initial value of the first accumulator lane for seed 0
!
  integer(kind=8), parameter :: prime1 = -7046029288634856825_8, &
                                prime2 = -4417276706812531889_8, &
                                prime3 =  1609587929392839161_8, &
                                prime4 = -8796714831421723037_8, &
                                prime5 =  2870177450012600261_8, &
                                prime6 =  6983438078262162902_8
#endif /* XXHASH */

! supported hash types
!
! enumerators take consecutive values from 0 in declaration order
! (hash_none = 0, hash_xxh64 = 1, hash_xxh3 = 2), so new entries should be
! appended, not inserted; hash_xxh3 exists only when built with libxxhash
!
  enum, bind(c)
    enumerator hash_none
    enumerator hash_xxh64
#ifdef XXHASH
    enumerator hash_xxh3
#endif /* XXHASH */
  end enum

  private
  public :: hash_info, hash_name, digest, check_digest
  public :: digest_string, digest_integer

!- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!
  contains
!
!===============================================================================
!!
!!***  PUBLIC SUBROUTINES  *****************************************************
!!
!===============================================================================
!
!===============================================================================
!
! subroutine HASH_INFO:
! --------------------
!
!   Subroutine returns the hash ID and the length of the digest string
!   (in characters) for the provided hash name. The 64-bit digests have
!   a length of 16 characters (hexadecimal digits); "none" has length 0.
!
!   If the name is not supported, a message is printed and the outputs
!   are set to hash_none and 0.
!
!   Arguments:
!
!     hash_name   - the hash name, e.g. "xxh64", "XXH64", "none";
!     hash_id     - the returned hash ID;
!     hash_length - the returned digest string length;
!
!===============================================================================
!
  subroutine hash_info(hash_name, hash_id, hash_length)

    use helpers, only : print_message

    implicit none

    character(len=*), intent(in)  :: hash_name
    integer         , intent(out) :: hash_id, hash_length

    character(len=*), parameter :: loc = "HASH::hash_info()"

!-------------------------------------------------------------------------------
!
! intent(out) arguments are undefined on entry, so set safe defaults first;
! otherwise an unsupported name would leave both outputs undefined
!
    hash_id     = hash_none
    hash_length = 0

    select case(trim(hash_name))
    case("xxh64", "XXH64")
      hash_id     = hash_xxh64
      hash_length = 16
#ifdef XXHASH
    case("xxh3", "XXH3")
      hash_id     = hash_xxh3
      hash_length = 16
#endif /* XXHASH */
    case("none", "NONE")
      hash_id     = hash_none
      hash_length = 0
    case default
      call print_message(loc, &
                           "Hash '" // trim(hash_name) // "' is not supported!")
    end select

    return

!-------------------------------------------------------------------------------
!
  end subroutine hash_info
!
!===============================================================================
!
! function HASH_NAME:
! ------------------
!
!   Function maps a hash ID to its lower-case name ("xxh64", "xxh3");
!   any ID that is not recognized, including hash_none, maps to "none".
!
!   Arguments:
!
!     hash_type - the hash ID;
!
!===============================================================================
!
  character(len=8) function hash_name(hash_type)

    implicit none

    integer, intent(in) :: hash_type

    integer(kind(hash_none)) :: id

!-------------------------------------------------------------------------------
!
! convert to the kind of the enumerators before comparing
!
    id = hash_type

    hash_name = "none"
    if (id == hash_xxh64) then
      hash_name = "xxh64"
#ifdef XXHASH
    else if (id == hash_xxh3) then
      hash_name = "xxh3"
#endif /* XXHASH */
    end if

!-------------------------------------------------------------------------------
!
  end function hash_name
!
!===============================================================================
!
! function DIGEST:
! ---------------
!
!   Function computes the 64-bit digest of a sequence of bytes using the
!   algorithm selected by the hash ID. XXH64 always uses seed 0. With
!   libxxhash the library routines are called; otherwise the internal XXH64
!   implementation is used. For hash_none or an unknown ID the result is 0.
!
!   Arguments:
!
!     buffer  - the C pointer to the first byte of the data;
!     length  - the number of bytes to hash;
!     hash_id - the hash ID;
!
!===============================================================================
!
  integer(kind=8) function digest(buffer, length, hash_id) result(hash)

    use iso_c_binding, only : c_ptr

    implicit none

    type(c_ptr)    , intent(in) :: buffer
    integer(kind=8), intent(in) :: length
    integer        , intent(in) :: hash_id

!-------------------------------------------------------------------------------
!
    if (hash_id == hash_xxh64) then
#ifdef XXHASH
      hash = xxh64_lib(buffer, length, 0_8)
    else if (hash_id == hash_xxh3) then
      hash = xxh3_lib(buffer, length)
#else /* XXHASH */
      hash = xxh64(buffer, length)
#endif /* XXHASH */
    else
      hash = 0_8
    end if

!-------------------------------------------------------------------------------
!
  end function digest
!
!===============================================================================
!
! subroutine CHECK_DIGEST:
! -----------------------
!
!   Subroutine recomputes the digest of the data and compares it with the
!   expected one. On mismatch a message naming the file is printed with
!   print_message. Nothing is checked when the hash ID is hash_none.
!
!   Arguments:
!
!     loc     - the location reported in the message;
!     fname   - the file name reported in the message;
!     buffer  - the C pointer to the first byte of the data;
!     length  - the number of bytes to hash;
!     bdigest - the expected digest;
!     hash_id - the hash ID;
!
!===============================================================================
!
  subroutine check_digest(loc, fname, buffer, length, bdigest, hash_id)

    use helpers      , only : print_message
    use iso_c_binding, only : c_ptr

    implicit none

    character(len=*), intent(in) :: loc, fname
    type(c_ptr)     , intent(in) :: buffer
    integer(kind=8) , intent(in) :: length
    integer(kind=8) , intent(in) :: bdigest
    integer         , intent(in) :: hash_id

    integer(kind=8) :: computed

!-------------------------------------------------------------------------------
!
    if (hash_id /= hash_none) then
      computed = digest(buffer, length, hash_id)
      if (computed /= bdigest) then
        call print_message(loc, trim(fname) // " seems to be corrupted!")
      end if
    end if

!-------------------------------------------------------------------------------
!
  end subroutine check_digest
!
!===============================================================================
!
! subroutine DIGEST_STRING:
! ------------------------
!
!   Subroutine writes the integer digest into a string as 16 hexadecimal
!   digits (zero padded). Characters past the 16th are blank-filled by the
!   internal write. If the string is shorter than 16 characters, it is left
!   unchanged and a message is printed.
!
!   Arguments:
!
!     idigest - the digest as integer;
!     sdigest - the digest as string;
!
!===============================================================================
!
  subroutine digest_string(idigest, sdigest)

    use helpers, only : print_message

    implicit none

    integer(kind=8) , intent(in)    :: idigest
    character(len=*), intent(inout) :: sdigest

    character(len=*), parameter :: loc = "HASH::digest_string()"

!-------------------------------------------------------------------------------
!
! a 64-bit digest needs exactly 16 hexadecimal digits
!
    if (len(sdigest) < 16) then
      call print_message(loc, &
                         "The string is too short to contain the whole digest!")
    else
      write(sdigest, fmt="(1z16.16)") idigest
    end if

!-------------------------------------------------------------------------------
!
  end subroutine digest_string
!
!===============================================================================
!
! subroutine DIGEST_INTEGER:
! -------------------------
!
!   Subroutine converts a hexadecimal digest string (as written by
!   DIGEST_STRING) back to its integer representation. Only the first
!   16 characters are read. If the string cannot be parsed, a message is
!   printed and the integer digest is set to zero, so that a subsequent
!   digest comparison reports a mismatch instead of the run aborting.
!
!   Arguments:
!
!     sdigest - the digest as string;
!     idigest - the digest as integer;
!
!===============================================================================
!
  subroutine digest_integer(sdigest, idigest)

    use helpers, only : print_message

    implicit none

    character(len=*), intent(in)  :: sdigest
    integer(kind=8) , intent(out) :: idigest

    character(len=*), parameter :: loc = "HASH::digest_integer()"

    integer :: ios

!-------------------------------------------------------------------------------
!
! without IOSTAT a malformed string (e.g. non-hexadecimal characters) would
! terminate the program and leave idigest undefined
!
    read(sdigest, fmt="(1z16)", iostat=ios) idigest

    if (ios /= 0) then
      idigest = 0_8
      call print_message(loc, "Cannot convert '" // trim(sdigest) // &
                              "' to an integer digest!")
    end if

!-------------------------------------------------------------------------------
!
  end subroutine digest_integer
!
!===============================================================================
!!
!!***  PRIVATE SUBROUTINES  ****************************************************
!!
!===============================================================================
!
#ifndef XXHASH
!===============================================================================
!
! function XXH64:
! --------------
!
!   Function computes the XXH64 digest (seed 0) of a sequence of bytes,
!   following the xxHash specification:
!     1. 32-byte stripes are mixed into four accumulator lanes, which are
!        then converged into a single value (prime5 if no full stripe);
!     2. the input length is added;
!     3. the remaining 8-byte words, one 4-byte word and single bytes are
!        mixed in;
!     4. the final avalanche scrambles the result.
!
!   NOTE: input words are reinterpreted with TRANSFER, which matches the
!         little-endian reads of the specification only on little-endian
!         hosts. The arithmetic relies on integer(kind=8) overflow wrapping
!         around modulo 2^64, which the Fortran standard does not guarantee.
!
!   Arguments:
!
!     buffer - the C pointer to the first byte of the data;
!     length - the number of bytes to hash;
!
!===============================================================================
!
  integer(kind=8) function xxh64(buffer, length) result(hash)

    use iso_c_binding, only : c_ptr, c_f_pointer

    implicit none

    type(c_ptr)    , intent(in) :: buffer
    integer(kind=8), intent(in) :: length

    integer                                :: k
    integer(kind=8)                        :: pos, left, word
    integer(kind=8), dimension(4)          :: acc, stripe
    integer(kind=1), dimension(:), pointer :: bytes

!-------------------------------------------------------------------------------
!
    call c_f_pointer(buffer, bytes, [ length ])

    pos  = 1_8
    left = length

! step 1: stripes of 32 bytes, each split into four 8-byte words
!
    if (left < 32_8) then
      hash = prime5
    else
      acc = [ prime6, prime2, 0_8, -prime1 ]

      do while (left >= 32_8)
        stripe = transfer(bytes(pos:pos+31), 1_8, 4)
        do k = 1, 4
          acc(k) = xxh64_round(acc(k), stripe(k))
        end do
        pos  = pos  + 32_8
        left = left - 32_8
      end do

      hash = xxh64_rotl(acc(1),  1) + xxh64_rotl(acc(2),  7)                   &
           + xxh64_rotl(acc(3), 12) + xxh64_rotl(acc(4), 18)
      do k = 1, 4
        hash = xxh64_merge(hash, acc(k))
      end do
    end if

! step 2: the total length enters the hash
!
    hash = hash + length

! step 3: remaining 8-byte words, one 4-byte word, then single bytes
!         (shorter pieces are zero-extended to 64 bits before mixing)
!
    do while (left >= 8_8)
      word = transfer(bytes(pos:pos+7), 1_8)
      hash = xxh64_rotl(ieor(hash, xxh64_round(0_8, word)), 27) * prime1       &
           + prime4
      pos  = pos  + 8_8
      left = left - 8_8
    end do

    if (left >= 4_8) then
      word = transfer([ bytes(pos:pos+3), 0_1, 0_1, 0_1, 0_1 ], 1_8)
      hash = xxh64_rotl(ieor(hash, word * prime1), 23) * prime2 + prime3
      pos  = pos  + 4_8
      left = left - 4_8
    end if

    do while (left > 0_8)
      word = transfer([ bytes(pos), 0_1, 0_1, 0_1, 0_1, 0_1, 0_1, 0_1 ], 1_8)
      hash = xxh64_rotl(ieor(hash, word * prime5), 11) * prime1
      pos  = pos  + 1_8
      left = left - 1_8
    end do

! step 4: final avalanche
!
    hash = xxh64_aval(hash)

!-------------------------------------------------------------------------------
!
  end function xxh64
!
!===============================================================================
!
! function XXH64_ROUND:
! --------------------
!
!   Function mixes one 8-byte input word into an accumulator lane:
!   the lane plus input * prime2, rotated left by 31 bits, times prime1.
!
!   Arguments:
!
!     lane  - the current lane value;
!     input - the 8-byte input word;
!
!===============================================================================
!
  integer(kind=8) function xxh64_round(lane, input)

    implicit none

    integer(kind=8), intent(in) :: lane, input

!-------------------------------------------------------------------------------
!
    xxh64_round = xxh64_rotl(lane + input * prime2, 31) * prime1

!-------------------------------------------------------------------------------
!
  end function xxh64_round
!
!===============================================================================
!
! function XXH64_MERGE:
! --------------------
!
!   Function folds an accumulator lane into the converged hash:
!   the hash XORed with round(0, lane), then times prime1 plus prime4.
!
!   Arguments:
!
!     hash - the hash to merge into;
!     lane - the lane being merged;
!
!===============================================================================
!
  integer(kind=8) function xxh64_merge(hash, lane)

    implicit none

    integer(kind=8), intent(in) :: hash, lane

!-------------------------------------------------------------------------------
!
    xxh64_merge = ieor(hash, xxh64_round(0_8, lane)) * prime1 + prime4

!-------------------------------------------------------------------------------
!
  end function xxh64_merge
!
!===============================================================================
!
! function XXH64_AVAL:
! -------------------
!
!   Function applies the final avalanche of XXH64: alternating logical
!   xor-shifts (by 33, 29 and 32 bits) and multiplications by prime2 and
!   prime3, so that every input bit affects every output bit.
!
!   Arguments:
!
!     hash - the hash to mix;
!
!===============================================================================
!
  integer(kind=8) function xxh64_aval(hash)

    implicit none

    integer(kind=8), intent(in) :: hash

    integer(kind=8) :: h

!-------------------------------------------------------------------------------
!
! ISHFT with a negative shift is a logical (zero-filling) right shift
!
    h = ieor(hash, ishft(hash, -33))
    h = h * prime2
    h = ieor(h, ishft(h, -29))
    h = h * prime3
    xxh64_aval = ieor(h, ishft(h, -32))

!-------------------------------------------------------------------------------
!
  end function xxh64_aval
!
!===============================================================================
!
! function XXH64_ROTL:
! -------------------
!
!   Function rotates the bits of a 64-bit word to the left by a given number
!   of positions. Bits shifted out on the left re-enter on the right.
!
!   Arguments:
!
!     byte   - the 64-bit word to rotate (a full word, not a single byte);
!     amount - the number of bit positions to rotate by (0 to 64);
!
!===============================================================================
!
  integer(kind=8) function xxh64_rotl(byte, amount)

    implicit none

    integer(kind=8), intent(in) :: byte
    integer(kind=4), intent(in) :: amount

!-------------------------------------------------------------------------------
!
! ISHFTC performs a circular shift over all bit_size(byte) = 64 bits
!
    xxh64_rotl = ishftc(byte, amount)

!-------------------------------------------------------------------------------
!
  end function xxh64_rotl
#endif /* ~XXHASH */

!===============================================================================
!
end module hash