/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#include <cassert>
#include <complex>
#include <iostream>
#include <string>
#include <vector>

#include <hpk/fft/makeFactory.hpp>

// Writes 'label' followed by a rows-by-cols table of 'data' to std::cout.
// Elements are read in row-major order and streamed with operator<<, so 'T'
// may be any streamable type (std::complex included).  Every row is indented,
// and a row with more than eight values is wrapped onto continuation lines,
// each wrap being marked by a trailing backslash.
template<class T>
void printData(std::string label, const T* data, int rows, int cols) {
    constexpr int kValuesPerLine = 8;
    const char* const kIndent = "        ";

    std::cout << label << ":\n";

    int next = 0;  // Row-major index of the next element to be printed.
    for (int row = 0; row < rows; ++row) {
        std::cout << kIndent;
        for (int col = 0; col < cols; ++col, ++next) {
            const bool startsNewLine = (col > 0) && (col % kValuesPerLine == 0);
            if (startsNewLine) {
                std::cout << " \\\n" << kIndent;
            }
            std::cout << data[next] << "  ";
        }
        std::cout << '\n';
    }

    // A blank line separates this table from whatever is printed next.
    std::cout << std::endl;
}

// Prints complex data formatted into rows and columns, with line continuations
// if necessary.  It assumes that real and imaginary elements (of type 'fp_t')
// are interleaved in memory: element (i, j) has its real part at
// data[2 * (cols * i + j)] and its imaginary part immediately after it, so
// 'data' must provide 2 * rows * cols values.
//
// Each element is printed as "<re><signed im>i", eight elements per output
// line.  The sign handling is self-contained: the real part is always printed
// without a forced '+', the imaginary part always with an explicit sign, and
// the formatting flags of std::cout are put back to the caller's settings
// before returning.
template<typename fp_t>
void printComplexData(std::string label, const fp_t* data, int rows, int cols) {
    // Remember the caller's flags; only the showpos bit is changed below.
    const std::ios_base::fmtflags savedFlags = std::cout.flags();

    std::cout << label << ":\n";
    for (int i = 0; i < rows; ++i) {
        std::cout << "        ";
        for (int j = 0; j < cols; ++j) {
            if (j % 8 == 0 && j > 0) std::cout << " \\\n        ";
            // Clear showpos explicitly for the real part so that the output
            // does not depend on the state in which the caller left std::cout.
            std::cout << std::noshowpos << data[(2 * cols * i) + (2 * j) + 0]
                      << std::showpos << data[(2 * cols * i) + (2 * j) + 1]
                      << "i  ";
        }
        std::cout << '\n';
    }
    std::cout << std::endl;

    std::cout.flags(savedFlags);
}

// Walks through four 1D FFT examples with a real-valued time domain and a
// complex-valued frequency domain: three in-place transforms (with the padding
// they need in the time domain) and one out-of-place transform.  Inputs and
// results are printed to std::cout.
//
// Returns 0 on success.  If the factory or any FFT object cannot be created,
// an error message is printed and -1 is returned.
int main() {
    // Creation failures are reported with explicit runtime checks rather than
    // assert(): assert() is compiled out when NDEBUG is defined, which would
    // let a null FFT pointer be dereferenced in release builds.  The message
    // goes to std::cout and the exit code is -1, matching the makeFactory()
    // failure path below.
    const auto fail = [](const char* what) {
        std::cout << "Error: " << what << " failed." << std::endl;
        return -1;
    };

    // Make a factory for single precision, real-valued time domain, and
    // complex-valued frequency domain.
    auto factory = hpk::fft::makeFactory<float, float, std::complex<float>>();
    if (factory) {
        std::cout << "Using " << *factory << "\n\n";
    } else {
        std::cout << "Error: makeFactory() failed" << std::endl;
        return -1;
    }

    std::cout << "Example #1: Time domain is 4 real points, batch is 2.\n"
              << "~~~~~~~~~~  The last two time domain columns are padding.\n";

    float inout4b2[] = {1.0f, 0.0f, 0.0f, 0.0f, 99.9f, 99.9f,
                        1.0f, 1.0f, 1.0f, 1.0f, 99.9f, 99.9f};
    printData("inout", inout4b2, 2, 6);  // Also print the padding.

    // Below, strides are omitted, so the minimum necessary padding is assumed.
    auto fft = factory->makeInplace<1>({4}, 2);  // 1D, 4 real points, batch 2.
    if (!fft) return fail("makeInplace()");
    std::cout << *fft << "\n\n";
    fft->forward(inout4b2);
    printComplexData("->fwd", inout4b2, 2, 3);  // There is no padding here!
    fft->backward(inout4b2);
    printData("->bwd", inout4b2, 2, 6);  // Also print the padding.

    std::cout << "Example #2: Time domain is 5 real points, batch is 2.\n"
              << "~~~~~~~~~~  We need one padding value in the time domain.\n";

    float inout5b2[] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 99.9f,
                        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 99.9f};
    printData("inout", inout5b2, 2, 6);

    // Below, strides are omitted, so the minimum necessary padding is assumed.
    fft = factory->makeInplace<1>({5}, 2);  // 1D, 5 real points, batch 2.
    if (!fft) return fail("makeInplace()");
    std::cout << *fft << "\n\n";
    {  // We can use the same scratch area more than once, so let's do so.
        auto scratch = allocateScratch(*fft);
        fft->forward(inout5b2, scratch);
        printComplexData("->fwd", inout5b2, 2, 3);
        fft->backward(inout5b2, scratch);
        printData("->bwd", inout5b2, 2, 6);
    }  // The scratch is destroyed here.  (It was a std::unique_ptr.)

    std::cout << "Example #3: Time domain is 4 real points, batch is 3.\n"
              << "~~~~~~~~~~  Note that extra padding is fine.\n";

    float inout4b3[] = {1.0f, 0.0f, 0.0f, 0.0f, 99.9f, 99.9f, 99.9f, 99.9f,
                        1.0f, 1.0f, 1.0f, 1.0f, 99.9f, 99.9f, 99.9f, 99.9f,
                        2.0f, 1.0f, 1.0f, 0.0f, 99.9f, 99.9f, 99.9f, 99.9f};
    printData("inout", inout4b3, 3, 8);

    fft = factory->makeInplace<1>({4}, {3, 8});  // batch is 3 with stride 8.
    if (!fft) return fail("makeInplace()");
    std::cout << *fft << "\n\n";
    auto scratch = allocateScratch(*fft);
    fft->forward(inout4b3, scratch);
    // Four complex columns per row are printed, i.e. all eight floats of each
    // row of 'inout4b3'.
    printComplexData("->fwd", inout4b3, 3, 4);
    fft->backward(inout4b3, scratch);
    printData("->bwd", inout4b3, 3, 8);

    std::cout << "Example #4: Time domain is 4 real points, batch is 2.\n"
              << "~~~~~~~~~~  No padding is needed for out-of-place FFTs.\n";

    const std::vector<float> vinput{2.0f, 0.0f, 0.0f, 0.0f,
                                    2.0f, 2.0f, 2.0f, 2.0f};
    printData("in", vinput.data(), 2, 4);

    // We cannot assign a unique pointer that owns an Ooplace FFT to the
    // variable 'fft', which is a std::unique_ptr<Inplace>.
    auto oopFft = factory->makeOoplace<1>({4}, 2);  // batch 2, default stride.
    if (!oopFft) return fail("makeOoplace()");
    std::cout << *oopFft << "\n\n";

    // We can reassign scratch; doing so frees the previous memory.
    scratch = allocateScratch(*oopFft);

    // The frequency domain is complex.
    std::vector<std::complex<float>> vcmplx(6);  // (batch 2) * (3 points)
    oopFft->forwardCopy(vinput.data(), vcmplx.data(), scratch);
    printData("out", vcmplx.data(), 2, 3);

    // The compute functions also do accept pointers to real types.
    // (batch 2) * (3 complex points) * (2 floats per complex point) == 12
    std::vector<float> vfloat(12);
    oopFft->forwardCopy(vinput.data(), vfloat.data(), scratch);
    printComplexData("out", vfloat.data(), 2, 3);

    const float scale = 0.5f;  // Let's scale by 1/sqrt(4).
    std::vector<std::complex<float>> vscaled(6);
    oopFft->scaleForwardCopy(&scale, vinput.data(), vscaled.data(), scratch);
    printData("scaled", vscaled.data(), 2, 3);
}
