/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#include <cassert>
#include <complex>
#include <cstddef>  // std::byte, std::size_t
#include <iostream>
#include <memory>  // std::unique_ptr, std::shared_ptr
#include <new>  // placement new in examples 4 and 5
#include <string>
#include <vector>

#include <hpk/fft/makeFactory.hpp>

// Writes 'label' (followed by a colon) and then a rows-by-cols table of the
// elements of 'data' to std::cout.  The element type 'T' may be complex; it
// only needs to be streamable.  The element in row r, column c is read from
// data[cols * r + c].  Every row is indented by eight spaces, and a row with
// more than eight columns is continued on the next line after a trailing
// backslash.  A blank line is written (and the stream flushed) at the end.
template<class T>
void printData(std::string label, const T* data, int rows, int cols) {
    const char* const indent = "        ";
    std::ostream& out = std::cout;

    out << label << ":\n";
    for (int row = 0; row < rows; ++row) {
        out << indent;
        for (int col = 0; col < cols; ++col) {
            // Start a continuation line before every eighth column.
            const bool wrapHere = (col > 0) && (col % 8 == 0);
            if (wrapHere) {
                out << " \\\n" << indent;
            }
            out << data[cols * row + col] << "  ";
        }
        out << '\n';
    }
    out << std::endl;
}

// Writes 'label' (followed by a colon) and then a rows-by-cols table of
// complex values to std::cout.  The values are stored as interleaved
// real/imaginary pairs of type 'fp_t', so the element in row r, column c
// has its real part at data[2 * cols * r + 2 * c] and its imaginary part
// immediately after it.  Each value is written as "<re><sign><im>i".
// Every row is indented by eight spaces, and a row with more than eight
// columns is continued on the next line after a trailing backslash.
// Note that the showpos flag of std::cout is cleared on return.
template<typename fp_t>
void printComplexData(std::string label, const fp_t* data, int rows, int cols) {
    std::cout << label << ":\n";
    for (int row = 0; row < rows; ++row) {
        std::cout << "        ";
        for (int col = 0; col < cols; ++col) {
            if (col > 0 && col % 8 == 0) {
                std::cout << " \\\n        ";
            }
            const int re = (2 * cols * row) + (2 * col);
            std::cout << data[re];
            // Force an explicit sign on the imaginary part only.
            std::cout.setf(std::ios_base::showpos);
            std::cout << data[re + 1] << "i  ";
            std::cout.unsetf(std::ios_base::showpos);
        }
        std::cout << '\n';
    }
    std::cout << std::endl;
}

// Runs a sequence of small FFT examples (numbered #0 through #5), printing
// the inputs and the results to std::cout.  Returns -1 if a factory or an
// FFT compute object cannot be made; otherwise falls off the end of main
// (i.e., returns 0).
int main() {
    // Simple introductory example of a one-dimensional two-point FFT
    std::cout << "Example #0: Simple two-point example.\n"
              << "~~~~~~~~~~  \n";
    float twoPoints[4] = {5.0f, 6.0f, 1.0f, 2.0f};
    printComplexData("input", twoPoints, 1, 2);
    // Watch me do the FFT on one line:
    // (For brevity, this one-liner does not check the intermediate smart
    // pointers for null; the examples further below show such checks.)
    hpk::fft::makeFactory<float>()->makeInplace({2})->forward(twoPoints);
    printComplexData("forward", twoPoints, 1, 2);

    // Let's make factories (for both single precision and double precision)
    // for FFTs with complex time and complex frequency domains.
    std::cout << "Setup: Making factories.\n"
              << "~~~~~  \n";
    // Construct a Configuration object from a one-item initializer list,
    // which specifies that compute objects are to be single threaded.
    // Equivalently, we could have written:
    // hpk::Configuration cfg{{hpk::Parameter::threads, 1}};
    hpk::Configuration cfg{hpk::sequential};
    // Below, auto is std::unique_ptr<hpk::fft::FactoryCC<float>>.
    auto factory_s = hpk::fft::makeFactory<float>(cfg);
    if (factory_s) {
        std::cout << "Using " << *factory_s << " for single precision.\n";
    } else {
        std::cout << "Error: makeFactory<float>() failed" << std::endl;
        return -1;
    }
    // Below, auto is std::unique_ptr<hpk::fft::FactoryCC<double>>.
    auto factory_d = hpk::fft::makeFactory<double>(cfg);
    if (factory_d) {
        std::cout << "Using " << *factory_d << " for double precision.\n";
    } else {
        std::cout << "Error: makeFactory<double>() failed" << std::endl;
        return -1;
    }
    std::cout << '\n';

    // Simple example of a one-dimensional four-point FFT
    std::cout << "Example #1: Simple four-point example.\n"
              << "~~~~~~~~~~  \n";
    std::vector<std::complex<double>> fourPoints = {1.0, 0.0, 0.0, 0.0};
    printData("input", fourPoints.data(), 1, 4);
#if __cpp_lib_ssize
    long ssize = std::ssize(fourPoints);  // C++20
#else
    long ssize = static_cast<long>(std::size(fourPoints));
#endif
    // As in example #0, the result of makeInplace() is used unchecked here.
    factory_d->makeInplace({ssize})->forward(fourPoints.data());
    printData("forward", fourPoints.data(), 1, 4);

    // One-dimensional eight-point FFT in single precision.
    std::cout << "Example #2: One-dimensional eight-point FFT (batch=1).\n"
              << "~~~~~~~~~~  \n";
    // The points are contiguous in memory, so strides are omitted.
    hpk::fft::InplaceDim layout[1] = {8};
    hpk::fft::InplaceDim batch = 1;
    // Below, auto is std::unique_ptr<hpk::fft::InplaceCC<float>>.
    // And, it's a good idea to state the number of dimensions explicitly,
    // both for readability and so the compiler can check it.
    auto sfft_8 = factory_s->makeInplace<1>(layout, batch);
    // Check explicitly (rather than with assert, which is compiled out when
    // NDEBUG is defined) so that a failure never leads to a null dereference.
    if (!sfft_8) {
        std::cout << "Error: makeInplace() failed." << std::endl;
        return -1;
    }
    // No spooky action at a distance; changes to layout and/or batch from
    // now on will not affect sfft_8.
    layout[0].n = 4;  // Perhaps the example after this will be four points.
    std::cout << *sfft_8 << '\n';
    float data[16] = {1.0f, 2.0f};
    printComplexData("input", data, 1, 8);
    // Scratch memory (if needed) will be automatically allocated and deleted.
    sfft_8->forward(data);
    printComplexData("forward", data, 1, 8);

    // One-dimensional four-point FFT in double precision.
    std::cout << "Example #3: One-dimensional four-point FFT (batch=1).\n"
              << "~~~~~~~~~~  \n";
    // Recall that, in the previous example, we've already initialized layout.
    std::cout << "Note: layout has " << std::size(layout)
              << " dimension, which is " << layout[0] << '\n';
    // Note that, below, the batch argument is omitted so it defaults to one.
    std::unique_ptr dfft_4 = factory_d->makeInplace<1>(layout);
    if (!dfft_4) {
        std::cout << "Error: makeInplace() failed." << std::endl;
        return -1;
    }
    {
        std::cout << *dfft_4 << '\n';
        double inout[8] = {1.0, 2.0};
        printComplexData("input", inout, 1, 4);
        // It's more efficient to make a scratch area once and reuse it.
        // The function allocateScratch() allocates memory and returns a
        // std::unique_ptr, so the allocated memory (if any) is deallocated
        // when the smart pointer goes out of scope.
        // Recall that dfft_4 is itself a (smart) pointer and so must be
        // dereferenced to be passed as an argument to allocateScratch().
        // By default, memory is allocated with an hpk::AlignedAllocator,
        // in this case hpk::AlignedAllocator<double>.
        // Below, auto is std::unique_ptr<double, Deleter>, and Deleter is
        // hpk::Deleter<hpk::AlignedAllocator<double>>.
        // The function template name ::hpk::fft::allocateScratch is found
        // by ADL (argument-dependent lookup) and the math type (double) is
        // deduced.
        auto scratch = allocateScratch(*dfft_4);
        dfft_4->forward(inout, scratch);
        printComplexData("forward", inout, 1, 4);
        dfft_4->backward(inout, scratch);
        printComplexData("backward", inout, 1, 4);
    }  // The memory owned by the std::unique_ptr scratch is deleted here.

    // In general, we don't know how much scratch space will be needed, and
    // this may change from one library release to another.
    // However, we can allocate a fixed amount of scratch space on the stack
    // and use it (when it's enough) to avoid dynamic memory allocation.
    // Note that leaving this block scope automatically deallocates stackArea.
    constexpr std::size_t kStackAreaBytes = 8192;
    alignas(64) std::byte stackArea[kStackAreaBytes];

    // Two-dimensional FFT with five rows and eight columns
    std::cout << "Example #4: Two-dimensional 5x8 FFT (batch=1).\n"
              << "~~~~~~~~~~  \n";
    auto sfft_5x8 = factory_s->makeOoplace<2>({5, 8});
    if (!sfft_5x8) {
        std::cout << "Error: makeOoplace() failed." << std::endl;
        return -1;
    }
    std::string fftString = sfft_5x8->toString();
    std::cout << fftString << '\n';
    alignas(64) float input[80] = {1.0f, 2.0f};
    alignas(64) float output[80] = {0.0f, 0.0f};
    printComplexData("input", input, 5, 8);
    std::size_t scratchSize = sfft_5x8->scratchSize();
    // Note that scratchSize() is measured in real (not complex) elements.
    if (scratchSize * sizeof(float) <= kStackAreaBytes) {
        // Below we use a non-allocating placement new.
        float* scratch = new (stackArea) float[scratchSize];
        sfft_5x8->forwardCopy(input, output, scratch);
    } else {
        // There is a function template hpk::allocateMemory(), which is
        // not specific to FFTs, that has a std::size_t parameter for the
        // number of real elements to allocate.  This argument may be 0,
        // in which case an empty std::unique_ptr is returned.
        // Since the function's arguments are C++ basic types, there are
        // no ADL-associated namespaces, and so one needs the "hpk::".
        auto scratch = hpk::allocateMemory<float>(scratchSize);
        sfft_5x8->forwardCopy(input, output, scratch);
    }
    printComplexData("output", output, 5, 8);

    // One-dimensional 60-point FFT in double precision.
    std::cout << "Example #5: One-dimensional 60-point FFT (batch=1).\n"
              << "~~~~~~~~~~  \n";
    // Initialize a shared_ptr if one desires to make copies later.
    std::shared_ptr dfft_60 = factory_d->makeInplace<1>({60});
    if (!dfft_60) {
        std::cout << "Error: makeInplace() failed." << std::endl;
        return -1;
    }
    std::cout << *dfft_60 << '\n';
    // It's a good idea to use a vector having cacheline aligned data.
    // (The literals are doubles, matching the element type of the vector.)
    std::vector<double, hpk::AlignedAllocator<double>> inout(120, 0.0);
    inout[0] = 1.0;
    inout[1] = 2.0;
    printComplexData("input", inout.data(), 1, 60);
    // Below, scratchSize() returns the number of doubles needed.
    scratchSize = dfft_60->scratchSize();
    // Since we already have scratchSize, it would be faster simply to
    // multiply it by sizeof(double), but for pedagogical reasons:
    if (dfft_60->scratchSizeBytes() <= kStackAreaBytes) {
        // This placement new ends the lifetime of the previous objects
        // in stackArea.  No destructor is called, which is OK for floats.
        double* scratch = new (stackArea) double[scratchSize];
        dfft_60->forward(inout.data(), scratch);
    } else {
        // Request scratch memory (with 128B alignment) by specifying an
        // Allocator for the template type parameter of the forward()
        // function.
        // If not specified, hpk::AlignedAllocator<double> would be used.
        // Note that the default alignment for hpk::AlignedAllocator is
        // 64, and this default also applies to allocateScratch().
        dfft_60->forward<hpk::AlignedAllocator<double, 128>>(inout.data());
    }
    printComplexData("forward", inout.data(), 1, 60);
}
