/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\
*  Copyright (C) 2023--2026, High Performance Kernels LLC                     *
*                                                                             *
*  This software and the related documents are High Performance Kernels LLC   *
*  copyrighted materials, and your use of them is governed by the express     *
*  license under which they were provided to you (License).                   *
*  Unless the License provides otherwise, you may not use, copy, reproduce,   *
*  modify, disclose, transmit, publish, or distribute this software or the    *
*  related documents without prior written permission from High Performance   *
*  Kernels LLC.                                                               *
*                                                                             *
*    This software and the related documents are provided as is, WITHOUT ANY  *
*  WARRANTY, without even the implied warranty of MERCHANTABILITY or FITNESS  *
*  FOR A PARTICULAR PURPOSE.                                                  *
\* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#include <dlfcn.h>

#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <hpk/fft/makeFactory.hpp>

// Prints a row-major `rows` x `cols` matrix of complex numbers to std::cout,
// preceded by a "label:" line and followed by a blank line.
//
// Real and imaginary parts (of type 'fp_t') must be interleaved in memory:
// element (i, j) has its real part at data[2 * (cols * i + j)] and its
// imaginary part at data[2 * (cols * i + j) + 1].
//
// Each element is printed as "re+imi" or "re-imi".  After every 8 elements
// within a row, a " \" line continuation is emitted so wide rows stay
// readable.
template<typename fp_t>
void printComplexData(const std::string& label, const fp_t* data, int rows,
                      int cols) {
    std::cout << label << ":\n";
    for (int i = 0; i < rows; ++i) {
        // First value of row i; each complex element spans two fp_t values.
        const fp_t* row = data + 2 * cols * i;
        std::cout << "        ";
        for (int j = 0; j < cols; ++j) {
            if (j % 8 == 0 && j > 0) std::cout << " \\\n        ";
            // std::showpos forces an explicit sign on the imaginary part
            // ("1+2i", "3-4i").  It is reset immediately afterwards so the
            // next real part prints without a forced '+'.
            std::cout << row[2 * j] << std::showpos << row[2 * j + 1]
                      << "i  " << std::noshowpos;
        }
        std::cout << '\n';
    }
    std::cout << std::endl;
}

// This demonstrates how a factory can be passed to a function, with the
// function taking ownership of the factory (the unique_ptr is taken by value,
// so the caller must std::move() it in).
//
// Runs a single in-place 2-D forward FFT on a 3x6 complex single precision
// toy problem whose only nonzero input element is (0, 0) = 1+2i, and prints
// the input, the FFT object, and the transformed data.
//
// Failures (a null factory, or makeInplace() returning null) are reported on
// std::cout and the function returns early.  These are runtime checks rather
// than asserts so that NDEBUG builds never dereference a null pointer.
void compute3x6(std::unique_ptr<hpk::fft::FactoryCC<float>> factory) {
    // This is just a toy problem.  Data layout is 3x6.
    constexpr long kRows = 3;
    constexpr long kCols = 6;

    if (!factory) {
        std::cout << "Error: compute3x6() received a null factory.\n";
        return;
    }

    std::cout << "Example #1: Two dimensional, In-place FFT.\n"
              << "~~~~~~~~~~  Rows=" << kRows << ", Cols=" << kCols << '\n';
    // Interleaved real/imaginary storage: two floats per complex element.
    std::vector<float> v(2 * kCols * kRows);
    v[0] = 1.0f;
    v[1] = 2.0f;
    printComplexData("input", v.data(), kRows, kCols);

    auto fft = factory->makeInplace<2>({kRows, kCols});
    if (!fft) {
        std::cout << "Error: makeInplace() failed for fft." << std::endl;
        return;
    }
    std::cout << *fft << '\n';
    fft->forward(v.data());
    printComplexData("forward", v.data(), kRows, kCols);
}

// Returns the message describing the most recent dlopen()/dlclose() failure,
// or a placeholder when dlerror() has nothing to report.
static const char* lastDlError() {
    const char* message = dlerror();
    return message ? message : "unknown error";
}

// Unloads a library handle obtained from dlopen().  RTLD_DEFAULT is not a
// handle returned by dlopen(), so it is never passed to dlclose().  A failing
// dlclose() is reported as a warning.
static void unloadLibrary(void* handle) {
    if (handle == RTLD_DEFAULT) return;
    if (dlclose(handle) != 0) {
        std::cout << "Warning: dlclose() failed: " << lastDlError() << ".\n";
    }
}

int main() {
    // For demonstration purposes, this program will NOT be linked with the
    // shared library libhpk_fft_avx512_fp32.so.
    // So, on AVX512 hardware, let's try to load it with dlopen().
    // If this succeeds, makeFactory() will make an avx512 factory.
    void* handle = RTLD_DEFAULT;
    if (hpk::detectArchitecture() >= hpk::Architecture::avx512) {
        handle = dlopen(hpk::fft::avx512_fp32_so, RTLD_LAZY);
        if (!handle) {
            // Include dlerror()'s explanation (e.g. file not found).
            std::cout << "Warning: dlopen(" << hpk::fft::avx512_fp32_so
                      << ") failed: " << lastDlError() << ".\n";
            handle = RTLD_DEFAULT;
        }
    }

    // Make a factory for complex single precision time and frequency domains
    // that uses only one thread, passing the handle from above.
    auto factory = hpk::fft::makeFactory<float>({hpk::sequential}, handle);
    if (!factory) {
        std::cout << "Error: makeFactory<float>() failed" << std::endl;
        // No factory was created, so nothing depends on the library yet.
        // Unload it instead of leaking the handle on this early exit.
        unloadLibrary(handle);
        return -1;
    }
    std::cout << "Using " << *factory << "\n\n";

    // Transfer ownership of the factory to the compute3x6() function, which
    // will do some FFTs and print the results.
    compute3x6(std::move(factory));

    // compute3x6() takes the factory by value (not by reference), so the
    // moved-into parameter is now the sole owner.  Its lifetime ends when
    // compute3x6() returns or at the end of the call expression above (the
    // choice is implementation-defined).  Either way, the factory is
    // destroyed before the next statement runs.
    // This local unique_ptr was emptied by the move:
    assert(!factory);
    // Furthermore, by examining compute3x6(), we see that all the FFT objects
    // that the factory created have also been destroyed since their owning
    // unique_ptrs (locals of compute3x6()) have gone out of scope.
    // Therefore, at this point, destructors have completed for all objects
    // that need hpk::fft::avx512_fp32_so (for destruction or anything else).
    // So, the library may be unloaded.
    unloadLibrary(handle);
}
