/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#include <complex>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

#include <hpk/fft/makeFactory.hpp>

// Prints 'rows' x 'cols' elements of 'data' (stored contiguously in row-major
// order) to std::cout under the heading "label:". Elements are written with
// their stream operator<<, so std::complex values appear as "(re,im)".
// Each row is indented by eight spaces and every element is followed by two
// spaces; a row longer than eight elements is wrapped onto continuation lines
// that end with a backslash.
//
// Parameters:
//   label - heading printed before the data, followed by a colon.
//   data  - pointer to at least rows * cols readable elements of type T.
//   rows  - number of rows to print; a non-positive value prints no rows.
//   cols  - number of elements per row; a non-positive value prints rows
//           that contain only the indentation.
template<class T>
void printData(const std::string& label, const T* data, int rows, int cols) {
    constexpr int kPerLine = 8;  // elements per output line before wrapping
    std::cout << label << ":\n";
    for (int i = 0; i < rows; ++i) {
        std::cout << "        ";
        for (int j = 0; j < cols; ++j) {
            if (j > 0 && j % kPerLine == 0) std::cout << " \\\n        ";
            // Widen before multiplying so rows * cols beyond INT_MAX cannot
            // overflow int arithmetic.
            const std::ptrdiff_t idx = static_cast<std::ptrdiff_t>(cols) * i + j;
            std::cout << data[idx] << "  ";
        }
        std::cout << '\n';
    }
    std::cout << std::endl;
}

int main() {
    // Transform length used by both examples. Every size below is derived
    // from this single constant so the vectors, the printouts, and the FFT
    // plans cannot drift out of sync (a stale hard-coded print width would
    // read past the end of a shorter vector).
    constexpr int kLength = 12;

    // Let's make factories (for both single precision and double precision)
    // for FFTs with complex time and complex frequency domains.
    std::cout << "Setup: Making factories.\n"
              << "~~~~~  \n";
    // Below, auto is std::unique_ptr<hpk::fft::FactoryCC<float>>.
    auto factory_s = hpk::fft::makeFactory<float>({hpk::sequential});
    if (!factory_s) {
        // Diagnostics go to stderr so they are not mixed into regular output.
        std::cerr << "Error: makeFactory<float>() failed" << std::endl;
        return EXIT_FAILURE;
    }
    std::cout << "Using " << *factory_s << " for single precision.\n";
    // Below, auto is std::unique_ptr<hpk::fft::FactoryCC<double>>.
    auto factory_d = hpk::fft::makeFactory<double>({hpk::sequential});
    if (!factory_d) {
        std::cerr << "Error: makeFactory<double>() failed" << std::endl;
        return EXIT_FAILURE;
    }
    std::cout << "Using " << *factory_d << " for double precision.\n";
    std::cout << '\n';

    // Example of a one-dimensional 12-point FFT in single precision. The
    // input is all zeros except element 0 (1+2i); the forward transform
    // overwrites the same buffer.
    std::cout << "Example #1: Twelve-point single precision example.\n"
              << "~~~~~~~~~~  \n";
    std::vector<std::complex<float>> vf(kLength);
    vf[0] = {1.0f, 2.0f};
    printData("input", vf.data(), 1, kLength);
    factory_s->makeInplace({static_cast<long>(kLength)})->forward(vf.data());
    printData("forward", vf.data(), 1, kLength);

    // Example of a one-dimensional 12-point FFT in double precision, with the
    // same input as above.
    std::cout << "Example #2: Twelve-point double precision example.\n"
              << "~~~~~~~~~~  \n";
    std::vector<std::complex<double>> vd(kLength);
    vd[0] = {1.0, 2.0};
    printData("input", vd.data(), 1, kLength);
    factory_d->makeInplace({static_cast<long>(kLength)})->forward(vd.data());
    printData("forward", vd.data(), 1, kLength);
}
