#include <Arduino.h>

#include <algorithm> // For std::max_element, std::min, std::max
#include <cmath>
#include <cstddef> // For std::size_t, std::ptrdiff_t
#include <cstdint> // For fixed-width integer types used by the BMP headers
#include <cstdio> // For print (which Serial.print might wrap)
#include <cstdlib> // For std::abs
// #include <ctime> // Using <sys/time.h> or similar is often required for embedded clocks
#include <fstream> // Assuming an embedded system like a Portenta/ESP32 has a file system/mock
#include <limits> // For std::numeric_limits
#include <numeric> // Numeric algorithms (std::accumulate, ...)
#include <string>
#include <vector>

#include "k230_fft.h"

using namespace arduino;

// Fallback value of PI; only defined here if no earlier header (e.g. Arduino.h) already provides it.
#ifndef PI
#define PI 3.14159265358979323846
#endif

// Nominal sample rate in Hz. In this file it is used to label FFT bins with frequencies
// (FFT::calculate_frequencies) and in the start-up banner; the synthetic test signal itself is
// defined in terms of FFT bins, not Hz.
constexpr int SAMPLE_RATE = 44100;

// The BMP headers are written to disk by copying these structs byte-for-byte, so the compiler must
// not insert padding between fields: pack to 1-byte alignment (popped again after BMPInfoHeader).
#pragma pack(push, 1)
/// BMP file header: the fixed 14-byte header at the very start of a .bmp file.
struct BMPFileHeader {
    uint16_t file_type { 0x4D42 }; // "BM" signature (bytes 'B','M' when stored little-endian)
    uint32_t file_size { 0 };      // total file size in bytes
    uint16_t reserved1 { 0 };
    uint16_t reserved2 { 0 };
    uint32_t offset_data { 0 };    // byte offset from the start of the file to the pixel array
};
// Fails the build if the packing pragma was ignored (the unpacked struct would be 16 bytes).
static_assert(sizeof(BMPFileHeader) == 14, "BMPFileHeader must be packed to exactly 14 bytes");

/// BMP info (DIB) header: 40 bytes describing the image. Defaults: 1 plane, 24 bpp, uncompressed.
struct BMPInfoHeader {
    uint32_t size { 0 };               // size of this header in bytes (40)
    int32_t  width { 0 };              // image width in pixels
    int32_t  height { 0 };             // image height; a negative value means rows are stored top-down
    uint16_t planes { 1 };             // number of color planes, always 1
    uint16_t bit_count { 24 };         // 24 bits per pixel
    uint32_t compression { 0 };        // 0 = BI_RGB (uncompressed)
    uint32_t size_image { 0 };         // size of the pixel array in bytes (may be 0 for BI_RGB)
    int32_t  x_pixels_per_meter { 0 };
    int32_t  y_pixels_per_meter { 0 };
    uint32_t colors_used { 0 };
    uint32_t colors_important { 0 };
};
static_assert(sizeof(BMPInfoHeader) == 40, "BMPInfoHeader must be exactly 40 bytes");
#pragma pack(pop)

/**
 * @brief Saves a frequency-amplitude plot as a 24-bit BMP image file.
 * @param freqs The vector of frequency data (X-axis).
 * @param amplitudes The vector of amplitude data (Y-axis).
 * @param filename The name of the output BMP file.
 * @param width The width of the image in pixels.
 * @param height The height of the image in pixels.
 */
void save_plot_to_bmp(const std::vector<int>& freqs, const std::vector<int>& amplitudes, const char* filename, int width = 800,
                      int height = 600)
{
    if (freqs.empty() || amplitudes.empty()) {
        Serial.print("Cannot save empty data to BMP.\n"); // Replaced std::cerr/std::endl
        return;
    }

    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file) {
        Serial.printf("Error opening file %s\n", filename); // Replaced std::cerr
        return;
    }

    // --- Prepare pixel data ---
    std::vector<uint8_t> pixel_data(width * height * 3, 255); // Initialize with white background (R=255, G=255, B=255)

    auto set_pixel = [&](int x, int y, uint8_t r, uint8_t g, uint8_t b) {
        if (x >= 0 && x < width && y >= 0 && y < height) {
            int index             = (y * width + x) * 3;
            pixel_data[index + 0] = b; // BMP format uses BGR order
            pixel_data[index + 1] = g;
            pixel_data[index + 2] = r;
        }
    };

    // Find max values for scaling
    int max_freq      = *std::max_element(freqs.begin(), freqs.end());
    int max_amplitude = *std::max_element(amplitudes.begin(), amplitudes.end());
    if (max_amplitude == 0)
        max_amplitude = 1;

    // Draw the data plot (blue line)
    int prev_x = -1, prev_y = -1;
    for (size_t i = 0; i < freqs.size(); ++i) {
        int x = static_cast<int>((static_cast<double>(freqs[i]) / max_freq) * (width - 1));
        // Note: Y-axis is inverted for plotting (0=bottom, height-1=top), which is standard for graphs
        // The BMP header's negative height handles the top-to-bottom pixel storage.
        int y = height - 1 - static_cast<int>((static_cast<double>(amplitudes[i]) / max_amplitude) * (height - 1));

        set_pixel(x, y, 0, 0, 255); // Draw a blue dot

        // Draw a line from the previous point to the current point (Bresenham's line algorithm)
        if (prev_x != -1) {
            int dx = std::abs(x - prev_x), sx = prev_x < x ? 1 : -1;
            int dy = -std::abs(y - prev_y), sy = prev_y < y ? 1 : -1;
            int err       = dx + dy, e2;
            int current_x = prev_x, current_y = prev_y;
            while (true) {
                set_pixel(current_x, current_y, 0, 0, 255);
                if (current_x == x && current_y == y)
                    break;
                e2 = 2 * err;
                if (e2 >= dy) {
                    err += dy;
                    current_x += sx;
                }
                if (e2 <= dx) {
                    err += dx;
                    current_y += sy;
                }
            }
        }
        prev_x = x;
        prev_y = y;
    }

    // --- Write BMP Headers ---
    BMPFileHeader file_header;
    BMPInfoHeader info_header;

    info_header.size  = sizeof(BMPInfoHeader);
    info_header.width = width;
    // Using a negative height to tell the BMP viewer that the data is top-down (simpler)
    info_header.height = -height;

    file_header.offset_data = sizeof(BMPFileHeader) + sizeof(BMPInfoHeader);
    file_header.file_size   = file_header.offset_data + pixel_data.size();

    file.write(reinterpret_cast<char*>(&file_header), sizeof(file_header));
    file.write(reinterpret_cast<char*>(&info_header), sizeof(info_header));
    file.write(reinterpret_cast<const char*>(pixel_data.data()), pixel_data.size());

    file.close();
    Serial.printf("Successfully saved plot to %s\n", filename); // Replaced std::cout
}

// ===================================================================
// UTILITY FUNCTIONS
// ===================================================================

/**
 * @brief Generates a real-valued test signal (time domain) as a sum of five cosines.
 *
 * Component k (k = 1..5) completes exactly k cycles over the buffer, i.e. it lies on FFT bin k.
 * The amplitudes are 10, 20, 30, 0.2 and 1000. Each sample is saturated to the range of `short`
 * and then truncated toward zero.
 *
 * @param fft_num The number of samples to generate.
 * @return The real part of the test signal (the imaginary part is implicitly zero); an empty
 *         vector when fft_num <= 0.
 */
std::vector<short> build_test_signal(int fft_num)
{
    // A negative size would make the vector constructor throw (or abort without exceptions).
    if (fft_num <= 0) {
        return {};
    }

    // Amplitude of the cosine that completes (index + 1) cycles over the buffer.
    static const float kAmplitudes[] = { 10.0f, 20.0f, 30.0f, 0.2f, 1000.0f };
    constexpr float    kShortMin     = static_cast<float>(std::numeric_limits<short>::min());
    constexpr float    kShortMax     = static_cast<float>(std::numeric_limits<short>::max());

    std::vector<short> rx(static_cast<size_t>(fft_num));
    for (int i = 0; i < fft_num; ++i) {
        float signal = 0.0f;
        float cycles = 1.0f;
        for (const float amplitude : kAmplitudes) {
            signal += amplitude * cosf(cycles * 2.0f * PI * i / fft_num);
            cycles += 1.0f;
        }
        // The default amplitudes sum to about 1060, so saturation only matters if they are raised.
        rx[i] = static_cast<short>(std::min(std::max(signal, kShortMin), kShortMax));
    }
    return rx;
}

/**
 * @brief Analyzes and prints the results of the FFT -> IFFT cycle.
 * @param point The number of FFT points.
 * @param original_real The initial real signal.
 * @param ifft_result The final signal after IFFT (interleaved [R, I]).
 * @param fft_time_us Time taken for the FFT operation in microseconds.
 * @param ifft_time_us Time taken for the IFFT operation in microseconds.
 */
static int display_calc_result(int point, const std::vector<short>& original_real, const std::vector<short>& ifft_result,
                               unsigned long long fft_time_us, unsigned long long ifft_time_us)
{
    // The IFFT output is interleaved: [R0, R1, ..., RN-1, I0, I1, ..., IN-1]
    const short* ifft_real = ifft_result.data();
    // The imaginary part of the original signal is 0, so we only check the real part.

    short max_diff_real = 0;

    for (int i = 0; i < point; ++i) {
        // Use std::abs for C++ compatibility and consistency
        short diff_r = std::abs(ifft_real[i] - original_real[i]);

        if (max_diff_real < diff_r) {
            max_diff_real = diff_r;
        }
    }

    Serial.printf("----- FFT/IFFT Point %04d -------\n", point);
    Serial.printf("\tFFT Time: %llu us, IFFT Time: %llu us\n", fft_time_us, ifft_time_us);
    Serial.printf("\tMax reconstruction error (Original R vs IFFT R): %d\n", max_diff_real);

    if (max_diff_real > 5) {
        Serial.printf("----- Result: ERROR (Max difference %d is too high) -----\n\n", max_diff_real);
        return -1;
    } else {
        Serial.print("----- Result: OK -----\n\n");
        return 0;
    }
}

// ===================================================================
// TEST FUNCTION (MAIN LOGIC)
// ===================================================================
void fft_test(int point)
{
    // 1. Prepare time-domain signal (R part only)
    std::vector<short> input_real_vec = build_test_signal(point);

    // --- Timing Start: FFT ---
    // Note: We need a new time point 'ifft_begin' to correctly exclude I/O
    struct timespec begain_time, fft_end, ifft_begin, ifft_end;

    // START TIMING
    clock_gettime(CLOCK_MONOTONIC, &begain_time);

    // 2. Perform FFT
    FFT                f(input_real_vec, point, K230_FFT_Defaults::FFT_SHIFT);
    std::vector<short> fft_result_vec = f.run();

    // END FFT TIMING
    clock_gettime(CLOCK_MONOTONIC, &fft_end);

    // ----------------------------------------------------
    // 3. PLOTTING: Calculate Spectrum and Save to BMP
    //    *** THIS SECTION IS SLOW I/O AND IS NOW EXCLUDED FROM IFFT TIMING ***
    // ----------------------------------------------------

    // Get Frequencies (X-axis)
    std::vector<int32_t> freqs_32 = FFT::calculate_frequencies(point, SAMPLE_RATE);
    std::vector<int>     plot_freqs(freqs_32.begin(), freqs_32.begin() + point / 2);

    // Get Amplitudes (Y-axis)
    std::vector<int32_t> ampl_32 = f.calculate_amplitudes(fft_result_vec);
    std::vector<int>     plot_amplitudes(ampl_32.begin(), ampl_32.begin() + point / 2);

    // Create a unique filename for the BMP output
    std::string filename = "fft_spectrum_" + std::to_string(point) + ".bmp";

    // Call the plotting function (Slow I/O happens here)
    save_plot_to_bmp(plot_freqs, plot_amplitudes, filename.c_str());
    Serial.printf("\tGenerated plot: %s\n", filename.c_str());

    // ----------------------------------------------------
    // 4. Perform IFFT
    // ----------------------------------------------------
    // Prepare IFFT inputs
    std::vector<short> ifft_input_real(fft_result_vec.begin(), fft_result_vec.begin() + point);
    std::vector<short> ifft_input_imag(fft_result_vec.begin() + point, fft_result_vec.end());

    // START IFFT TIMING
    clock_gettime(CLOCK_MONOTONIC, &ifft_begin);

    IFFT               i(ifft_input_real, ifft_input_imag, point, K230_FFT_Defaults::IFFT_SHIFT);
    std::vector<short> ifft_result_vec = i.run();

    // END IFFT TIMING
    clock_gettime(CLOCK_MONOTONIC, &ifft_end);

    // 5. Calculate Timings
    unsigned long long fft_time_us
        = (fft_end.tv_sec - begain_time.tv_sec) * 1000000 + (fft_end.tv_nsec - begain_time.tv_nsec) / 1000;

    // **FIXED CALCULATION:** Time is now measured from ifft_begin to ifft_end
    unsigned long long ifft_time_us
        = (ifft_end.tv_sec - ifft_begin.tv_sec) * 1000000 + (ifft_end.tv_nsec - ifft_begin.tv_nsec) / 1000;

    // 6. Display results
    display_calc_result(point, input_real_vec, ifft_result_vec, fft_time_us, ifft_time_us);
}

// ===================================================================
// SETUP & MAIN
// ===================================================================

/// Arduino entry point: opens the serial port, then runs the FFT/IFFT round trip once for every
/// power-of-two size from 64 to 4096 points, smallest first.
void setup()
{
    Serial.begin(115200);
    // Spin until Serial reports it is ready, so the banner is not lost.
    while (!Serial) {
    }

    Serial.printf("Starting K230 FFT/IFFT C++ Test Harness (Sample Rate: %d Hz)\n", SAMPLE_RATE);

    for (int points = 64; points <= 4096; points *= 2) {
        fft_test(points);
    }

    Serial.print("All tests complete. Check the generated 'fft_spectrum_N.bmp' files.\n");
}

/// Arduino main loop: all testing is done once in setup(), so this only idles.
void loop()
{
    sleep(1);
}
