/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#ifndef HPK_FFT_MAKEFACTORY_HPP_INCLUDED
#define HPK_FFT_MAKEFACTORY_HPP_INCLUDED

/// \file
/// \brief This header provides a function for making concrete factories.

#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#ifndef HPK_FFT_NDLSYM
#include <dlfcn.h>  // dlopen, dlsym, dlclose
#endif

#include <hpk/configuration.hpp>
#include <hpk/detection.hpp>
#include <hpk/fft/factory.hpp>
#include <hpk/visibility.hpp>

/// The major version of this release of the Hpk library, as a C string.
/// It is concatenated onto the shared-object filenames below, so that, e.g.,
/// the AVX2 32-bit library is named "libhpk_fft_avx2_fp32.so.0".
#define HPK_MAJOR_VERSION_C_STR "0"

// The macro HPK_HAVE_FLOAT16_TYPE indicates whether the compiler supports the
// type _Float16.  It can be set to either 0 or 1 before including this file;
// otherwise it is set using the logic below:
//   - 1 if a half-precision FFT backend is compiled in
//     (HPK_HAVE_FFT_AVX512_FP16 or HPK_HAVE_FFT_SVE256_FP16) or the target
//     enables AVX512-FP16 (__AVX512FP16__);
//   - else, on compilers that provide clang's __is_identifier, 1 exactly when
//     _Float16 is a keyword rather than an ordinary identifier;
//   - else 0.
// NOTE(review): compilers without __is_identifier fall through to 0 even if
// they accept _Float16; define the macro explicitly for those if needed.
#ifndef HPK_HAVE_FLOAT16_TYPE
#if defined(HPK_HAVE_FFT_AVX512_FP16) || defined(HPK_HAVE_FFT_SVE256_FP16) \
        || defined(__AVX512FP16__)
#define HPK_HAVE_FLOAT16_TYPE 1
#else
#ifdef __is_identifier  // clang language extension
// __is_identifier(X) is 1 for an ordinary identifier and 0 for a keyword, so
// _Float16 NOT being an identifier means the compiler recognizes the type.
#if __is_identifier(_Float16)
#define HPK_HAVE_FLOAT16_TYPE 0
#else
#define HPK_HAVE_FLOAT16_TYPE 1
#endif
#else
#define HPK_HAVE_FLOAT16_TYPE 0
#endif
#endif
#endif

namespace hpk {
namespace fft {

#ifndef HPK_FFT_NDLSYM
// Versioned filenames of the optional HPK FFT shared objects.  Each ends in
// the library major version, e.g. "libhpk_fft_avx2_fp32.so.0".  Nothing in
// this header references them; presumably they are provided for callers that
// dlopen() a backend and forward the handle to makeFactory().
// NOTE(review): confirm the intended use against the library documentation.

/// Library filename for AVX2 shared object supporting 32 bit precision.
inline constexpr char avx2_fp32_so[] =
        "libhpk_fft_avx2_fp32.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for AVX2 shared object supporting 64 bit precision.
inline constexpr char avx2_fp64_so[] =
        "libhpk_fft_avx2_fp64.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for AVX512_FP16 shared object supporting 16 bit
/// (half) precision.
inline constexpr char avx512_fp16_so[] =
        "libhpk_fft_avx512_fp16.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for AVX512 shared object supporting 32 bit precision.
inline constexpr char avx512_fp32_so[] =
        "libhpk_fft_avx512_fp32.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for AVX512 shared object supporting 64 bit precision.
inline constexpr char avx512_fp64_so[] =
        "libhpk_fft_avx512_fp64.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for SVE256 shared object supporting 16 bit (half)
/// precision.
inline constexpr char sve256_fp16_so[] =
        "libhpk_fft_sve256_fp16.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for SVE256 shared object supporting 32 bit precision.
inline constexpr char sve256_fp32_so[] =
        "libhpk_fft_sve256_fp32.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for SVE256 shared object supporting 64 bit precision.
inline constexpr char sve256_fp64_so[] =
        "libhpk_fft_sve256_fp64.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for shared object utilizing Intel's OpenMP library.
inline constexpr char iomp_so[] = "libhpk_fft_iomp.so." HPK_MAJOR_VERSION_C_STR;

/// Library filename for shared object utilizing LLVM's OpenMP library.
inline constexpr char omp_so[] = "libhpk_fft_omp.so." HPK_MAJOR_VERSION_C_STR;

// Helper function.
// Resolves `symbol` with dlsym() and returns the function pointer of type T
// stored there.  The symbol must name an object that holds a T: the address
// dlsym() returns is read as a T*, not called.
//   - With no handles, the global symbol scope (RTLD_DEFAULT) is searched.
//   - Otherwise each handle is tried in order and the first match wins.
// Returns nullptr when the symbol is not found.
template<typename T, typename... Handles>
inline T getMakeFactoryPtr(const char* symbol, Handles... handles) {
    void* address = nullptr;
    if constexpr (sizeof...(Handles) == 0) {
        address = dlsym(RTLD_DEFAULT, symbol);
    } else {
        for (void* handle : {static_cast<void*>(handles)...}) {
            address = dlsym(handle, symbol);
            if (address != nullptr) {
                break;  // Earlier handles take precedence.
            }
        }
    }
    if (address == nullptr) {
        return nullptr;
    }
    return *static_cast<T*>(address);
}

// Helper function for creating symbol names.
template<class F> inline void appendSymbolSuffix(std::string& symbol) {
    if constexpr (std::is_same_v<typename F::mathType, float>)
        symbol.append("32");
    else if constexpr (std::is_same_v<typename F::mathType, double>)
        symbol.append("64");
#if HPK_HAVE_FLOAT16_TYPE
    else if constexpr (std::is_same_v<typename F::mathType, _Float16>)
        symbol.append("16");
#endif
    if constexpr (is_complex_v<typename F::freqType>) {
        if constexpr (is_complex_v<typename F::timeType>)
            symbol.append("_cc");
        else
            symbol.append("_rc");
    }
}

#endif  // #ifndef HPK_FFT_NDLSYM

namespace seq {

// Sequential (single-threaded) factory makers for specific instruction sets.
// They are only declared here and are marked HPK_API.  makeFactory() calls one
// of them directly when the matching HPK_HAVE_FFT_<ARCH>_FP<N> macro is
// defined; otherwise it uses makeFactory_dl() to find an implementation in a
// shared object at run time.

/// Makes a sequential factory for Architecture::avx2.
template<typename fp_t, typename time_t, typename freq_t>
HPK_API std::unique_ptr<hpk::fft::Factory<fp_t, time_t, freq_t>>
makeFactory_avx2(const Configuration& cfg);

/// Makes a sequential factory for Architecture::avx512 (also used for
/// Architecture::avx512fp16 with _Float16 math).
template<typename fp_t, typename time_t, typename freq_t>
HPK_API std::unique_ptr<hpk::fft::Factory<fp_t, time_t, freq_t>>
makeFactory_avx512(const Configuration& cfg);

/// Makes a sequential factory for Architecture::sve256.
template<typename fp_t, typename time_t, typename freq_t>
HPK_API std::unique_ptr<hpk::fft::Factory<fp_t, time_t, freq_t>>
makeFactory_sve256(const Configuration& cfg);

// Finds and calls a sequential makeFactory implementation exported by a shared
// object, for builds in which that implementation is not linked in.
// The symbol prefix comes from `arch`: every architecture ordered at or above
// Architecture::avx512 uses the avx512 prefix.  appendSymbolSuffix<F> then
// completes the name.  `handles` are forwarded to getMakeFactoryPtr().
// Returns an empty pointer if `arch` has no known prefix, if the symbol is not
// found, or if dynamic lookup is compiled out (HPK_FFT_NDLSYM).
template<class F, typename... Handles>
std::unique_ptr<F> makeFactory_dl([[maybe_unused]] Architecture arch,
                                  [[maybe_unused]] const Configuration& cfg,
                                  [[maybe_unused]] Handles... handles) {
#ifndef HPK_FFT_NDLSYM
    const char* prefix = nullptr;
    if (arch >= Architecture::avx512) {
        prefix = "hpk_fft_seq_makeFactory_avx512_fp";
    } else if (arch == Architecture::avx2) {
        prefix = "hpk_fft_seq_makeFactory_avx2_fp";
    } else if (arch == Architecture::sve256) {
        prefix = "hpk_fft_seq_makeFactory_sve256_fp";
    }
    if (prefix == nullptr) {
        return {};  // No shared-object implementation for this architecture.
    }

    std::string symbol{prefix};
    appendSymbolSuffix<F>(symbol);

    using seq_make_fn = std::unique_ptr<F> (*)(const Configuration&);
    if (const auto make =
                getMakeFactoryPtr<seq_make_fn>(symbol.c_str(), handles...)) {
        return make(cfg);
    }
#endif
    return {};
}

}  // namespace seq

namespace omp {

// Takes ownership of a sequential factory and returns the multithreaded
// (OpenMP-layer) factory built from it.  This is only declared here.
// makeFactory() calls it directly when HPK_HAVE_FFT_OMP is defined and
// otherwise uses makeFactory_dl().
// NOTE(review): behavior for an empty seqFactory is not visible from here.
template<class F>
HPK_API std::unique_ptr<F> makeFactory(std::unique_ptr<F>&& seqFactory,
                                       const Configuration& cfg);
template<class F, typename... Handles>
std::unique_ptr<F> makeFactory_dl(std::unique_ptr<F>&& seqFactory,
                                  [[maybe_unused]] const Configuration& cfg,
                                  [[maybe_unused]] Handles... handles) {
#ifndef HPK_FFT_NDLSYM
    std::string symbol{"hpk_fft_omp_makeFactory_fp"};
    appendSymbolSuffix<F>(symbol);
    using makeFactory_ptr_t =
            std::unique_ptr<F> (*)(std::unique_ptr<F>&&, const Configuration&);
    auto makeFactory_ptr =
            getMakeFactoryPtr<makeFactory_ptr_t>(symbol.c_str(), handles...);
    if (makeFactory_ptr) {
        return makeFactory_ptr(std::move(seqFactory), cfg);
    }
#endif
    return std::move(seqFactory);
}

}  // namespace omp

/// \brief Makes a concrete instance of an hpk::fft::Factory.
/// \param cfg     Optionally provides a Configuration instance specifying
///                parameters and values.
/// \param handles Optionally provides dynamic shared object handles as
///                returned by `dlopen()`.  When an implementation is resolved
///                at run time, its symbol is searched for in these handles,
///                in order.  With no handles, the global symbol scope
///                (`RTLD_DEFAULT`) is searched.
/// \return `std::unique_ptr` that owns a Factory or, in case of failure,
///         is empty.
///
/// The first template parameter must be specified.  It is the floating-point
/// math type used for FFT computations and it is the type of the scale factor
/// if scaling is later applied.  (If `math_t` is complex, `std::complex` will
/// be removed from the type.)
///
/// The second and third template parameters, which specify the data type in
/// the time domain and frequency domain respectively, default to the complex
/// number type of the first template parameter.
/// Therefore, if only the first template parameter is specified, the factory
/// will make FFTs that compute complex-to-complex transforms.
///
/// Examples:
///
///     // Single precision, complex time domain, complex freq domain
///     auto factory1 = hpk::fft::makeFactory<float>();
///
///     // As above, but uses the AVX2 library regardless of AVX512 presence.
///     // Moreover, this factory will be owned by a shared_ptr, and so copies
///     // can be made (e.g., using the shared_ptr copy constructor).
///     std::shared_ptr factory2 =
///             hpk::fft::makeFactory<float>({hpk::Architecture::avx2});
///
///     // As above, and also single-threaded.
///     hpk::Configuration cfg{hpk::Architecture::avx2, hpk::sequential};
///     std::shared_ptr factory3 = hpk::fft::makeFactory<float>(cfg);
///
///     // Double precision, real time domain, complex freq domain
///     std::unique_ptr factory4 =
///             hpk::fft::makeFactory<double, double, std::complex<double>>();
///
///     // Since freq_t defaults to complex, the following is exactly the same
///     // as factory4.  Although factory5 requires fewer keystrokes, it is
///     // probably less readable.
///     auto factory5 = hpk::fft::makeFactory<double, double>();
///
template<typename math_t, typename time_t = add_complex_t<math_t>,
         typename freq_t = add_complex_t<math_t>, typename... Handles>
[[nodiscard]] auto makeFactory(const Configuration& cfg = Configuration(),
                               Handles... handles) {
    using fp_t = remove_complex_t<std::remove_cv_t<math_t>>;
    using timeType = std::remove_cv_t<time_t>;
    using freqType = std::remove_cv_t<freq_t>;
    using Factory_t = Factory<fp_t, timeType, freqType>;
    static_assert((std::is_pointer_v<Handles> && ...),
                  "hpk::fft::makeFactory: handles must be pointers, such as "
                  "those returned by dlopen()");

    // Resolve Architecture::detect via detectArchitecture().  Remember whether
    // detection happened, because that gates the avx512 -> avx2 fallback.
    Architecture arch = cfg.getArchitecture();
    const bool archDetected = arch == Architecture::detect;
    if (archDetected) {
        arch = detectArchitecture();
    }

    std::unique_ptr<Factory_t> seqFactory;  // Owns an hpk::fft::seq factory.

    // Each branch below uses the statically linked implementation when its
    // HPK_HAVE_FFT_* macro is defined, and otherwise looks one up at run time.

#if HPK_HAVE_FLOAT16_TYPE
    if (arch == Architecture::avx512fp16) {
        if constexpr (std::is_same_v<fp_t, _Float16>) {
#if defined(HPK_HAVE_FFT_AVX512_FP16)
            seqFactory = seq::makeFactory_avx512<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(arch, cfg, handles...);
#endif
        }
    }
#endif  // HPK_HAVE_FLOAT16_TYPE

    if (arch >= Architecture::avx512) {
        if constexpr (std::is_same_v<fp_t, float>) {
#if defined(HPK_HAVE_FFT_AVX512_FP32)
            seqFactory = seq::makeFactory_avx512<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(arch, cfg, handles...);
#endif
        }
        if constexpr (std::is_same_v<fp_t, double>) {
#if defined(HPK_HAVE_FFT_AVX512_FP64)
            seqFactory = seq::makeFactory_avx512<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(arch, cfg, handles...);
#endif
        }
    }

    // On avx512 hardware, when the avx512 shared library is not available
    // drop down to the avx2 factory if and only if archDetected is true,
    // i.e., Architecture::avx512 was not explicitly specified.
    if (arch == Architecture::avx2
        || (archDetected && arch >= Architecture::avx2 && !seqFactory)) {
        if constexpr (std::is_same_v<fp_t, float>) {
#if defined(HPK_HAVE_FFT_AVX2_FP32)
            seqFactory = seq::makeFactory_avx2<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(Architecture::avx2, cfg,
                                                        handles...);
#endif
        }
        if constexpr (std::is_same_v<fp_t, double>) {
#if defined(HPK_HAVE_FFT_AVX2_FP64)
            seqFactory = seq::makeFactory_avx2<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(Architecture::avx2, cfg,
                                                        handles...);
#endif
        }
    }

    if (arch == Architecture::sve256) {
#if HPK_HAVE_FLOAT16_TYPE
        if constexpr (std::is_same_v<fp_t, _Float16>) {
#if defined(HPK_HAVE_FFT_SVE256_FP16)
            seqFactory = seq::makeFactory_sve256<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(arch, cfg, handles...);
#endif
        }
#endif  // HPK_HAVE_FLOAT16_TYPE
        if constexpr (std::is_same_v<fp_t, float>) {
#if defined(HPK_HAVE_FFT_SVE256_FP32)
            seqFactory = seq::makeFactory_sve256<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(arch, cfg, handles...);
#endif
        }
        if constexpr (std::is_same_v<fp_t, double>) {
#if defined(HPK_HAVE_FFT_SVE256_FP64)
            seqFactory = seq::makeFactory_sve256<fp_t, timeType, freqType>(cfg);
#else
            seqFactory = seq::makeFactory_dl<Factory_t>(arch, cfg, handles...);
#endif
        }
    }

    // An empty seqFactory means no implementation was available for this
    // precision/architecture.  Report that failure directly as an empty
    // pointer, as documented, instead of handing a null factory to the
    // OpenMP layer.  That layer's handling of null is not visible here, and
    // passing it would also trigger a pointless symbol lookup.
    if (!seqFactory || cfg.get(Parameter::threads) == 1) return seqFactory;
#if defined(HPK_HAVE_FFT_OMP)
    return omp::makeFactory<Factory_t>(std::move(seqFactory), cfg);
#else
    return omp::makeFactory_dl<Factory_t>(std::move(seqFactory), cfg,
                                          handles...);
#endif
}

}  // namespace fft
}  // namespace hpk

#endif  // HPK_FFT_MAKEFACTORY_HPP_INCLUDED
