/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#include <cassert>
#include <complex>
#include <iostream>
#include <string>
#include <vector>

#include <hpk/fft/makeFactory.hpp>

// Writes 'label' and then a rows-by-cols table of complex numbers to
// std::cout.  'data' must hold interleaved (real, imaginary) values of type
// 'fp_t', stored row after row.  Each table row is indented; a row with more
// than eight columns is wrapped using a trailing backslash continuation.
template<typename fp_t>
void printComplexData(std::string label, const fp_t* data, int rows, int cols) {
    constexpr int kEntriesPerLine = 8;
    const char* const indent = "        ";

    std::cout << label << ":\n";
    for (int row = 0; row < rows; ++row) {
        // Every complex entry occupies two consecutive fp_t values.
        const fp_t* rowData = data + (2 * cols * row);
        std::cout << indent;
        for (int col = 0; col < cols; ++col) {
            const bool wrapHere = (col > 0) && (col % kEntriesPerLine == 0);
            if (wrapHere) {
                std::cout << " \\\n" << indent;
            }
            const fp_t& re = rowData[2 * col];
            const fp_t& im = rowData[2 * col + 1];
            // The sign is forced only on the imaginary part, giving "a+bi" or
            // "a-bi"; the flag is cleared again before the next real part.
            std::cout << re;
            std::cout << std::showpos << im << "i  " << std::noshowpos;
        }
        std::cout << '\n';
    }
    std::cout << std::endl;
}

// Demonstrates handing a factory to a function that takes ownership of it
// (the std::unique_ptr is received by value).  The function computes the
// forward FFT of one 2x4 complex array three different ways, printing the
// input and the output each time.
void compute2x4(std::unique_ptr<hpk::fft::FactoryCC<float>> factory) {
    // Toy problem size: 2 rows by 4 columns of complex values.
    constexpr long kRows = 2;
    constexpr long kCols = 4;

    // Copies the srcRows x srcCols array of interleaved complex values at
    // 'src' into 'dst', laid out as its srcCols x srcRows transpose.
    const auto copyTransposed = [](const float* src, float* dst, long srcRows,
                                   long srcCols) {
        for (long r = 0; r < srcRows; ++r) {
            for (long c = 0; c < srcCols; ++c) {
                const float* from = src + 2 * (r * srcCols + c);
                float* to = dst + 2 * (c * srcRows + r);
                to[0] = from[0];  // real part
                to[1] = from[1];  // imaginary part
            }
        }
    };

    // ---- Example #1 ----
    // A single out-of-place 2D transform: the fastest (and easiest) approach.

    std::cout << "Example #1: Two dimensional, Out-of-place FFT.\n"
              << "~~~~~~~~~~  Rows=" << kRows << ", Cols=" << kCols << '\n';
    const float input[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
                           1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f};
    float output[2 * kRows * kCols];
    printComplexData("in", input, kRows, kCols);

    auto fft_2D = factory->makeOoplace<2>({kRows, kCols});
    assert(fft_2D && "Error: makeOoplace() failed for fft_2D.");
    std::cout << *fft_2D << '\n';
    fft_2D->forwardCopy(input, output);
    printComplexData("out", output, kRows, kCols);

    // ---- Example #2 ----
    // The same transform built only from 1D unit stride FFTs: transform the
    // rows, transpose, transform the former columns as rows, transpose back.

    std::cout << "Example #2: Same problem as in the previous example,\n"
              << "~~~~~~~~~~  but for fun we'll do it the slow way.\n";
    printComplexData("in", input, kRows, kCols);

    auto fft_rows = factory->makeOoplace<1>({kCols}, /*batch=*/kRows);
    assert(fft_rows && "Error: makeOoplace() failed for fft_rows.");
    std::cout << *fft_rows << '\n';
    fft_rows->forwardCopy(input, output);

    std::cout << "Transpose to a temporary array\n";
    float temp[2 * kRows * kCols];
    copyTransposed(output, temp, kRows, kCols);

    auto fft_colsAsRows = factory->makeInplace<1>({kRows}, /*batch=*/kCols);
    assert(fft_colsAsRows && "Error: makeInplace() failed for fft_colsAsRows.");
    std::cout << *fft_colsAsRows << '\n';
    fft_colsAsRows->forward(temp);

    std::cout << "Transpose to the out array\n";
    // 'temp' is laid out as kCols x kRows, so the source shape is swapped.
    copyTransposed(temp, output, kCols, kRows);

    printComplexData("out", output, kRows, kCols);

    // ---- Example #3 ----
    // Once more, with 1D unit stride FFTs along the rows and 1D strided FFTs
    // down the columns, so that no transposes are required.

    std::cout << "Example #3: Same problem as in the previous example,\n"
              << "~~~~~~~~~~  but avoiding the transposes.\n";
    printComplexData("in", input, kRows, kCols);

    // fft_rows from Example #2 is reused for the rows.
    // The column transform is a strided 1D FFT of length kRows whose
    // consecutive points are one row apart, i.e., 2 * kCols.  Its batch can
    // be viewed as a SIMD vector of complex numbers: the batch size (vector
    // length) is kCols and neighboring batch elements are 2 apart.
    // Strides are counted in real values (floats), not complex elements.
    auto fft_cols = factory->makeInplace<1>({{kRows, 2 * kCols}}, {kCols, 2});
    assert(fft_cols && "Error: makeInplace() failed for fft_cols.");

    // One scratch allocation that is suitable for both FFT compute objects.
    auto workspace = hpk::allocateScratch(*fft_rows, *fft_cols);

    std::cout << *fft_rows << '\n';
    fft_rows->forwardCopy(input, output, workspace);

    std::cout << *fft_cols << '\n';
    fft_cols->forward(output, workspace);

    printComplexData("out", output, kRows, kCols);
}

// Runs an in-place 2D forward FFT on a numRows x numCols complex array that
// is zero except for a single impulse at element 0, then checks (via assert)
// that the first and last elements of the result equal the impulse.
// The factory is received by value, so this function takes ownership of it.
// Both dimensions must be positive; otherwise a message is printed and the
// function returns without computing anything.
void compute_2d(std::unique_ptr<hpk::fft::FactoryCC<float>> factory,
                long numRows, long numCols) {
    std::cout << "Example #4: Two dimensional, Inplace FFT.\n"
              << "~~~~~~~~~~  Rows=" << numRows << ", Cols=" << numCols << '\n';

    // With a zero dimension the array below would be empty, and writing the
    // impulse to v[0] (as well as reading front()/back()) would be undefined
    // behavior.  A negative dimension would yield a bogus allocation size.
    if (numRows <= 0 || numCols <= 0) {
        std::cout << "Error: numRows and numCols must both be positive.\n\n";
        return;
    }

    // The vector value-initializes its elements, so everything is zero
    // except the impulse stored in the first element.
    std::vector<std::complex<float>> v(numRows * numCols);
    const std::complex<float> impulse(1.0f, 2.0f);
    v[0] = impulse;

    auto fft_2D = factory->makeInplace<2>({numRows, numCols});
    assert(fft_2D && "Error: makeInplace() failed for fft_2D.");
    std::cout << "Running forward()... ";
    fft_2D->forward(v.data());
    std::cout << "finished.\n\n";

    // Spot-check the result: the first and last output elements are expected
    // to equal the impulse value.
    assert(v.front() == impulse);
    assert(v.back() == impulse);
}

int main() {
    // Side length of the square problem handed to compute_2d().
    constexpr long kSize2d = 1024;

    // First factory: complex single precision time and frequency domains,
    // limited to one thread.  hpk::sequential is a convenient name for
    // {hpk::Parameter::threads, 1}.
    auto factory = hpk::fft::makeFactory<float>({hpk::sequential});
    if (!factory) {
        std::cout << "Error: single-threaded makeFactory() failed" << std::endl;
        return -1;
    }
    std::cout << "Using " << *factory << "\n\n";

    // Hand the factory over to compute2x4(), which does some FFTs and prints
    // the results.  After the move the local pointer must be empty.
    compute2x4(std::move(factory));
    assert(!factory);

    // Second factory: the default configuration, which uses OpenMP
    // parallelism.
    factory = hpk::fft::makeFactory<float>();
    if (!factory) {
        std::cout << "Error: parallel makeFactory() failed" << std::endl;
        return -1;
    }
    std::cout << "Using " << *factory << "\n\n";

    // Ownership of this factory is transferred as well.
    compute_2d(std::move(factory), kSize2d, kSize2d);
    assert(!factory);
}
