/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#ifndef HPK_ALIGNEDALLOCATOR_HPP_INCLUDED
#define HPK_ALIGNEDALLOCATOR_HPP_INCLUDED

/// \file
/// \brief This header defines the class AlignedAllocator and the functions
///        allocateMemory() and allocateScratch().

#include <cstdlib>      // std::size_t, std::aligned_alloc
#include <memory>       // std::unique_ptr, std::allocator_traits
#include <new>          // std::bad_alloc, std::bad_array_new_length
#include <type_traits>  // std::true_type, std::is_trivially_destructible
#include <utility>      // std::forward, std::declval

namespace hpk {

// Default alignment (in bytes) used by AlignedAllocator.  When the standard
// library defines the feature-test macro checked below, the name refers to
// std::hardware_constructive_interference_size; otherwise a constant with
// the same name and the value 64 is defined in namespace hpk.
#ifdef __cpp_lib_hardware_interference_size
using ::std::hardware_constructive_interference_size;
#else
inline constexpr std::size_t hardware_constructive_interference_size = 64;
#endif

/// \brief A stateless allocator that hands out over-aligned storage.
///
/// Storage is obtained from `std::aligned_alloc` and returned with
/// `std::free`.  The class provides the members needed to be used as the
/// allocator of standard containers (`value_type`, `rebind`, `allocate`,
/// `deallocate`, and a converting constructor).
///
/// The alignment defaults to `hardware_constructive_interference_size` as
/// declared in this namespace (the standard library's value when available,
/// otherwise 64 bytes).  It must be a power of two that is at least
/// `sizeof(void*)` and at least `alignof(T)`.
///
/// Examples:
///
///     // A std::vector of 10 value-initialized floats whose buffer uses the
///     // default alignment.  (The type differs from std::vector<float>,
///     // which uses std::allocator.)
///     auto v1 = std::vector<float, hpk::AlignedAllocator<float>>(10);
///
///     // An empty std::vector whose buffer will be 32B-aligned.
///     // (v1 and v2 have different types.)
///     auto v2 = std::vector<float, hpk::AlignedAllocator<float, 32>>();
///
template<class T, std::size_t align = hardware_constructive_interference_size>
struct AlignedAllocator {
    using value_type = T;  ///< `T`, a cv-unqualified object type
    using is_always_equal = std::true_type;  ///< No per-instance state.

    static_assert(align >= sizeof(void*),
                  "AlignedAllocator requires align >= sizeof(void*).");
    static_assert(align >= alignof(T),
                  "AlignedAllocator<T> requires align >= alignof(T).");
    static_assert((align & (align - 1)) == 0,
                  "AlignedAllocator requires align be a power of two.");

    /// Rebinds to another value type while keeping the same alignment.
    template<class U> struct rebind {
        using other = AlignedAllocator<U, align>;
    };

    /// Default constructor (there is no state to initialize).
    constexpr AlignedAllocator() noexcept {}

    /// Converting constructor from an allocator with the same alignment and
    /// any value type `U`.
    template<class U>
    constexpr AlignedAllocator(const AlignedAllocator<U, align>&) noexcept {}

    /// \brief Returns uninitialized, `align`-aligned storage for `n` objects
    ///        of type `T`.
    ///
    /// No objects are constructed.  The request is rounded up to a multiple
    /// of `align` before being passed to `std::aligned_alloc`.
    ///
    /// \param n  Number of elements; `0` yields `nullptr`.
    /// \return   Pointer to the storage, or `nullptr` when `n == 0`.  When
    ///           exceptions are disabled, `nullptr` is also returned on
    ///           failure.
    /// \throws std::bad_array_new_length  if the rounded byte count would
    ///           not fit in `std::size_t` (only when exceptions are enabled).
    /// \throws std::bad_alloc  if `std::aligned_alloc` returns null (only
    ///           when exceptions are enabled).
    [[nodiscard]] T* allocate(std::size_t n) const {
        constexpr std::size_t size_max = static_cast<std::size_t>(-1);
        // Largest element count whose byte size can still be rounded up to
        // a multiple of `align` without wrapping around.
        constexpr std::size_t max_count = (size_max - align + 1) / sizeof(T);

        if (n == 0) {
            return nullptr;
        }
        if (n > max_count) {
#ifdef __cpp_exceptions
            throw std::bad_array_new_length();
#else
            return nullptr;
#endif
        }

        // `align` is a power of two, so masking with ~(align - 1) rounds
        // down to a multiple of it.
        const std::size_t padded = (n * sizeof(T) + (align - 1)) & ~(align - 1);
        void* const storage = std::aligned_alloc(align, padded);
#ifdef __cpp_exceptions
        if (storage == nullptr) {
            throw std::bad_alloc();
        }
#endif
        return static_cast<T*>(storage);
    }

    /// \brief Releases storage previously obtained from `allocate()`.
    ///
    /// The element count is ignored, and no destructor is run on the
    /// objects that may live in the storage.
    void deallocate(T* p, std::size_t) const noexcept { std::free(p); }
};

/// \brief Equality comparison; always `true`.
///
/// Any two `AlignedAllocator` instances compare equal, whatever their value
/// types and alignments, because storage obtained from one of them can be
/// released through any other.
/// \related AlignedAllocator
template<class T1, std::size_t align1, class T2, std::size_t align2>
constexpr bool operator==(const AlignedAllocator<T1, align1>& /*lhs*/,
                          const AlignedAllocator<T2, align2>& /*rhs*/) noexcept {
    return true;
}

/// \brief Inequality comparison; always `false`.
///
/// Defined for any pair of `AlignedAllocator` specializations, whatever
/// their value types and alignments.
/// \related AlignedAllocator
template<class T1, std::size_t align1, class T2, std::size_t align2>
constexpr bool operator!=(const AlignedAllocator<T1, align1>& /*lhs*/,
                          const AlignedAllocator<T2, align2>& /*rhs*/) noexcept {
    return false;
}

// ========================================================================

/// \brief Deleter that returns a buffer to the Allocator it came from.
///
/// Intended as the deleter type of the `std::unique_ptr` produced by
/// `allocateMemory()`.  It stores a copy of the Allocator together with the
/// element count, and hands both back to the Allocator's `deallocate()`.
/// No destructor is run on the elements, which is why the Allocator's
/// `value_type` is required to be trivially destructible.
template<class Allocator> class Deleter {
    using Traits = std::allocator_traits<Allocator>;

 public:
    using value_type = typename Traits::value_type;
    using size_type = typename Traits::size_type;
    using pointer = typename Traits::pointer;

    static_assert(std::is_trivially_destructible_v<value_type>,
                  "Allocator::value_type must be trivially destructible");

    /// Default constructor: element count zero, so calling the deleter
    /// does nothing.
    Deleter() noexcept(noexcept(Allocator())) {}

    /// Constructs a `Deleter` from a copy of Allocator `a` and the number
    /// of elements `n` in the buffer it will release.
    Deleter(const Allocator& a, size_type n) noexcept
        : allocator_(a), count_(n) {}

    /// Releases `ptr` through the stored Allocator, unless the stored
    /// element count is zero.
    void operator()(pointer ptr) {
        if (count_ == 0) {
            return;
        }
        Traits::deallocate(allocator_, ptr, count_);
    }

 private:
    Allocator allocator_;  // copy of the Allocator used for deallocation
    size_type count_ = 0;  // element count handed back to deallocate()
};

/// \brief Allocates memory for `n` elements of type `T`
/// \return `std::unique_ptr` that owns the allocated memory.
///
/// More specifically, the return value is a
///     `std::unique_ptr<T, hpk::Deleter<Allocator>>`.
/// If Allocator is not specified, `hpk::AlignedAllocator<T>` is the default.
/// Note that `T` must be trivially destructible (e.g., a non-class type
/// compatible with the C language).
///
/// Examples:
///
///     // Allocate (the default) 64B-aligned memory for 10 floats.
///     auto tmp1 = hpk::allocateMemory<float>(10);
///
///     // Allocate 8B-aligned memory for 20 floats.
///     auto tmp2 =
///         hpk::allocateMemory<float, hpk::AlignedAllocator<float, 8>>(20);
///
///     // Allocate memory for 30 floats using the standard allocator.
///     auto tmp3 = hpk::allocateMemory<float, std::allocator<float>>(30);
///
///     // Use an instance to allocate 128B-aligned memory for 40 doubles.
///     hpk::AlignedAllocator<double, 128> alloc;
///     auto tmp4 = hpk::allocateMemory<double>(40, alloc);
///
template<typename T, class Allocator = AlignedAllocator<T>>
[[nodiscard]] auto allocateMemory(std::size_t n,
                                  Allocator&& alloc = Allocator()) {
    using AllocatorType = std::decay_t<Allocator>;

    static_assert(std::is_trivially_destructible_v<T>,
                  "T must be trivially destructible");
    static_assert(std::is_same_v<typename AllocatorType::value_type, T>,
                  "Allocator::value_type must be T");
    if (n) {
        Deleter<AllocatorType> deleter(alloc, n);
        auto buf = std::allocator_traits<AllocatorType>::allocate(alloc, n);
        if (buf) {
            T* ptr = new (static_cast<void*>(buf)) T[n];
            return std::unique_ptr<T, Deleter<AllocatorType>>(ptr, deleter);
        }
    }
    Deleter<AllocatorType> deleter(alloc, 0);
    return std::unique_ptr<T, Deleter<AllocatorType>>(nullptr, deleter);
}

// ========================================================================

// Type trait `is_allocator<A>`.
//
// Derives from std::true_type when `A` (with any reference removed) has a
// nested `value_type` and the expression `a.deallocate(a.allocate(0), 0)`
// is well-formed for an object `a` of type `A`; derives from
// std::false_type otherwise.  Only these two properties are probed.

// Fallback: `A` does not look like an Allocator.
template<typename A, typename Enable = void>
struct is_allocator : std::false_type {};

// Selected when both probes below are well-formed.
template<typename A>
struct is_allocator<
        A, std::void_t<typename std::remove_reference_t<A>::value_type,
                       decltype(std::declval<A>().deallocate(
                               std::declval<A>().allocate(0), 0))>>
    : std::true_type {};

// Shorthand for `is_allocator<A>::value`.
template<typename A>
inline constexpr bool is_allocator_v = is_allocator<A>::value;

// Type trait `uses_scratch<C>`.
//
// Derives from std::true_type when `C` (with any reference removed) has a
// nested `mathType` and the call `c.scratchSize()` is well-formed for an
// object `c` of type `C`; derives from std::false_type otherwise.

// Fallback: `C` does not use scratch memory.
template<typename C, typename Enable = void>
struct uses_scratch : std::false_type {};

// Selected when both probes below are well-formed.
template<typename C>
struct uses_scratch<
        C, std::void_t<typename std::remove_reference_t<C>::mathType,
                       decltype(std::declval<C>().scratchSize())>>
    : std::true_type {};

// Shorthand for `uses_scratch<C>::value`.
template<typename C>
inline constexpr bool uses_scratch_v = uses_scratch<C>::value;

// ========================================================================

/// \brief Allocates scratch memory large enough for every compute argument.
/// \return `std::unique_ptr` that owns the allocated memory.
///
/// The largest `scratchSize()` among the arguments is determined, and that
/// many elements of the arguments' common `mathType` are obtained from
/// `allocateMemory()` with its default Allocator.  All arguments must have
/// the same `mathType`.
///
/// Example:
///
///     // Given three pointers to FFT compute objects, the following
///     // allocates scratch memory which is suitable for any of them:
///     auto scratch = hpk::allocateScratch(*fft1, *fft2, *fft3);
///
template<typename... Args,
         typename std::enable_if_t<std::conjunction_v<uses_scratch<Args>...>,
                                   int> = 0>
[[nodiscard]] auto allocateScratch(const Args&... args) {
    // The comma fold yields its last operand, so this names the mathType
    // of the last argument; the assertion below checks the others match.
    using real_t = typename decltype((
            std::enable_if<true, typename Args::mathType>{}, ...))::type;
    static_assert(
            (std::is_same_v<typename Args::mathType, real_t> && ...),
            "All compute objects must have the same mathType.");

    std::size_t largest = 0;
    const auto consider = [&largest](std::size_t candidate) {
        if (candidate > largest) {
            largest = candidate;
        }
    };
    (consider(args.scratchSize()), ...);
    return ::hpk::allocateMemory<real_t>(largest);
}

// Final step of the Allocator-taking allocateScratch() overloads: allocates
// max(`min`, `arg.scratchSize()`) elements through `alloc`.  The element
// type is the Allocator's value_type, which must equal `Arg::mathType`.
template<typename Arg, typename Allocator,
         typename std::enable_if_t<uses_scratch_v<Arg>, int> = 0>
[[nodiscard]] auto allocateScratch(std::size_t min, const Arg& arg,
                                   Allocator&& alloc) {
    static_assert(is_allocator_v<Allocator>,
                  "Last argument to allocateScratch() must be an Allocator");
    using real_t = typename std::remove_reference_t<Allocator>::value_type;
    static_assert(std::is_same_v<real_t, typename Arg::mathType>,
                  "Allocator's value_type must be compute object's mathType");

    std::size_t count = min;
    const std::size_t needed = arg.scratchSize();
    if (needed > count) {
        count = needed;
    }
    return ::hpk::allocateMemory<real_t>(count,
                                         std::forward<Allocator>(alloc));
}

// Recursive step of the Allocator-taking allocateScratch() overloads: folds
// `arg1.scratchSize()` into the running minimum `min` and passes the result
// on together with the remaining arguments (further compute objects, then
// the Allocator).  `arg1` and `arg2` must have the same mathType.
template<typename Arg1, typename Arg2, typename... Args,
         typename std::enable_if_t<uses_scratch_v<Arg1> && uses_scratch_v<Arg2>,
                                   int> = 0>
[[nodiscard]] auto allocateScratch(std::size_t min, const Arg1& arg1,
                                   Arg2&& arg2, Args&&... args) {
    using first_real_t = typename std::remove_reference_t<Arg1>::mathType;
    using second_real_t = typename std::remove_reference_t<Arg2>::mathType;
    static_assert(
            std::is_same_v<first_real_t, second_real_t>,
            "All compute objects must have the same mathType.");

    std::size_t running = min;
    const std::size_t needed = arg1.scratchSize();
    if (needed > running) {
        running = needed;
    }
    return ::hpk::allocateScratch(running, std::forward<Arg2>(arg2),
                                  std::forward<Args>(args)...);
}

/// \brief Allocates scratch memory large enough for every compute argument,
///        using the Allocator passed as the final argument.
/// \return `std::unique_ptr` that owns the allocated memory.
///
/// The largest `scratchSize()` among the compute objects is determined and
/// that many elements are allocated through the given Allocator instance.
/// This overload only starts the computation with a minimum size of zero;
/// the work is done by the overloads taking a leading `std::size_t`.
///
/// Example:
///
///     // Given three pointers to double precision compute objects, the
///     // following allocates scratch memory suitable for any of them:
///     hpk::AlignedAllocator<double> alloc;
///     auto scratch = hpk::allocateScratch(*fft1, *fft2, *fft3, alloc);
///
template<typename Arg, typename... Args,
         typename std::enable_if_t<std::disjunction_v<is_allocator<Args>...>
                                           && uses_scratch_v<Arg>,
                                   int> = 0>
[[nodiscard]] auto allocateScratch(Arg&& arg, Args&&... args) {
    // Seed the running minimum with zero elements.
    return ::hpk::allocateScratch(static_cast<std::size_t>(0),
                                  std::forward<Arg>(arg),
                                  std::forward<Args>(args)...);
}

}  // namespace hpk

#endif  // HPK_ALIGNEDALLOCATOR_HPP_INCLUDED
