#include "k230_fft.h"
#include <cmath>
#include <numeric> // NOTE(review): nothing in this file appears to use <numeric>; std::log2 is declared in <cmath>

// Put implementation details in the .cpp file
namespace { // File-local helpers with internal linkage

/// Counts the 1-bits in the binary representation of a positive integer.
///
/// @param n Value to inspect.
/// @return Number of set bits in `n`; 0 when `n` is zero or negative.
int count_set_bits(int n)
{
    // Non-positive values are reported as having no set bits.
    if (n <= 0) {
        return 0;
    }

    int ones = 0;
    // Walk the value one bit at a time, testing the lowest bit each round.
    for (unsigned int remaining = static_cast<unsigned int>(n); remaining != 0U; remaining >>= 1U) {
        if ((remaining & 1U) != 0U) {
            ++ones;
        }
    }
    return ones;
}

} // namespace

// ============== FFT Implementation ==============

/// Builds an FFT job from a buffer of real-valued samples.
///
/// The samples are copied into the object and an output buffer of
/// `2 * point` entries is allocated.
///
/// @param real_input Real input samples; must contain exactly `point` elements.
/// @param point      Transform length; must be accepted by check_point().
/// @param shift      Shift setting that run() later forwards to kd_mpi_fft().
/// @throws std::runtime_error    if check_point() rejects `point`.
/// @throws std::invalid_argument if `real_input.size()` differs from `point`.
FFT::FFT(const std::vector<int16_t>& real_input, int point, int shift)
    : m_input_real(real_input)
    , m_point(point)
    , m_shift(shift)
{
    // Reject unsupported transform lengths before looking at the data size.
    check_point(m_point);

    const bool input_size_matches = (m_input_real.size() == m_point);
    if (!input_size_matches) {
        throw std::invalid_argument("Input data size must match FFT point count.");
    }

    // One allocation holds both halves of the result: run() hands the driver
    // the start of the buffer and the position `m_point` elements further on.
    m_output_buffer.resize(2 * m_point);
}

// Using a helper function instead of a macro
/// Validates a transform length.
///
/// Accepted values are exactly 64, 128, 256, 512, 1024, 2048 and 4096.
///
/// @param point Candidate transform length.
/// @throws std::runtime_error if `point` is not one of the accepted values.
void FFT::check_point(int point)
{
    // The range test runs first, so `point - 1` is only evaluated for values
    // in [64, 4096] and cannot overflow.
    const bool supported = (point >= 64) && (point <= 4096) && ((point & (point - 1)) == 0);
    if (supported) {
        return;
    }

    throw std::runtime_error("Invalid FFT point. Must be a power of 2 from 64 to 4096.");
}

/// Executes the forward transform through kd_mpi_fft().
///
/// The stored real samples are supplied for both input pointers (RRRR input
/// mode), and the result is written into the internal output buffer
/// (RR_II_OUT output mode).
///
/// @return A copy of the output buffer: `2 * m_point` values, where the second
///         half starts at the pointer given to the driver for imaginary output.
/// @throws std::runtime_error if kd_mpi_fft() returns a non-zero status.
std::vector<int16_t> FFT::run()
{
    // The C API works on raw pointers; take them from the owning vectors.
    auto* const samples     = m_input_real.data();
    auto* const real_output = m_output_buffer.data();
    auto* const imag_output = real_output + m_point;

    // The same sample buffer is passed for both input arguments in RRRR mode.
    const int status = kd_mpi_fft(m_point, RRRR, RR_II_OUT, 0, m_shift, samples, samples, real_output, imag_output);
    if (status != 0) {
        throw std::runtime_error("kd_mpi_fft execution failed.");
    }

    // Returned by value so the caller owns an independent copy.
    return m_output_buffer;
}

/// Computes the frequency associated with each FFT bin.
///
/// Bin `i` maps to `sample_rate * i / point`, truncated toward zero. The
/// function uses no instance state.
///
/// @param point       Number of bins to generate (the transform length).
/// @param sample_rate Sampling rate of the input signal.
/// @return A vector of `point` frequencies; empty when `point` is 0.
std::vector<int32_t> FFT::calculate_frequencies(int point, int sample_rate)
{
    std::vector<int32_t> freqs(point);

    // Use exact 64-bit integer arithmetic. A single-precision product
    // `step * i` has only a 24-bit significand, so for large bins it can round
    // up to the next whole number before truncation and report a frequency
    // that is 1 too high. Integer division truncates toward zero, matching the
    // float-to-integer conversion it replaces.
    const long long rate = sample_rate;
    for (int i = 0; i < point; ++i) {
        freqs[i] = static_cast<int32_t>(rate * i / point);
    }
    return freqs;
}

/// Converts a split real/imaginary spectrum into per-bin amplitudes.
///
/// `complex_data` is read as `m_point` real parts followed by `m_point`
/// imaginary parts. For each bin the magnitude `sqrt(re^2 + im^2)` is truncated
/// to an integer, divided by `m_point` (bin 0) or doubled and then divided by
/// `m_point` (all other bins), and finally shifted left by the number of set
/// bits in `m_shift`.
/// NOTE(review): the left shift presumably undoes per-stage scaling selected by
/// the shift mask — confirm against the K230 FFT documentation.
///
/// @param complex_data Spectrum with exactly `2 * m_point` elements.
/// @return `m_point` amplitude values.
/// @throws std::invalid_argument if `complex_data` has the wrong size.
std::vector<int32_t> FFT::calculate_amplitudes(const std::vector<int16_t>& complex_data)
{
    if (complex_data.size() != 2 * m_point) {
        throw std::invalid_argument("Complex data size is incorrect for amplitude calculation.");
    }

    const int shift_bit_count = count_set_bits(m_shift);

    std::vector<int32_t> amplitudes(m_point);
    for (int bin = 0; bin < m_point; ++bin) {
        // Widen to 32 bits so squaring a 16-bit sample cannot overflow.
        const int32_t re = complex_data[bin];
        const int32_t im = complex_data[bin + m_point];

        const double   squared_norm = static_cast<double>(re * re) + (im * im);
        const uint32_t magnitude    = static_cast<uint32_t>(std::sqrt(squared_norm));

        // Bin 0 is scaled by 1/m_point; every other bin by 2/m_point.
        uint32_t scaled = 0;
        if (bin == 0) {
            scaled = magnitude / m_point;
        } else {
            scaled = 2 * magnitude / m_point;
        }

        amplitudes[bin] = scaled << shift_bit_count;
    }
    return amplitudes;
}

std::vector<int32_t> FFT::calculate_amplitudes() { return calculate_amplitudes(m_output_buffer); }

// ============== IFFT Implementation ==============

/// Builds an inverse-FFT job from separate real and imaginary buffers.
///
/// The FFT base constructor validates `point`, copies and size-checks
/// `real_input`, and allocates the output buffer. This constructor then copies
/// `imag_input` and checks its size.
///
/// @param real_input Real parts; must contain exactly `point` elements.
/// @param imag_input Imaginary parts; must contain exactly `point` elements.
/// @param point      Transform length; must be accepted by FFT::check_point().
/// @param shift      Shift setting that run() later forwards to kd_mpi_ifft().
/// @throws std::runtime_error    if `point` is rejected (raised by the base constructor).
/// @throws std::invalid_argument if either input buffer's size differs from `point`.
IFFT::IFFT(const std::vector<int16_t>& real_input, const std::vector<int16_t>& imag_input, int point, int shift)
    : FFT(real_input, point, shift)
    , m_input_imag(imag_input)
{
    const bool imag_size_matches = (m_input_imag.size() == m_point);
    if (!imag_size_matches) {
        throw std::invalid_argument("Imaginary input data size must match IFFT point count.");
    }
}

/// Executes the inverse transform through kd_mpi_ifft().
///
/// The stored real and imaginary buffers are supplied as the two inputs (RIRI
/// input mode), and the result is written into the internal output buffer
/// (RR_II_OUT output mode).
///
/// @return A copy of the output buffer: `2 * m_point` values, where the second
///         half starts at the pointer given to the driver for imaginary output.
/// @throws std::runtime_error if kd_mpi_ifft() returns a non-zero status.
std::vector<int16_t> IFFT::run()
{
    // The C API works on raw pointers; take them from the owning vectors.
    auto* const real_input  = m_input_real.data();
    auto* const imag_input  = m_input_imag.data();
    auto* const real_output = m_output_buffer.data();
    auto* const imag_output = real_output + m_point;

    const int status
        = kd_mpi_ifft(m_point, RIRI, RR_II_OUT, 0, m_shift, real_input, imag_input, real_output, imag_output);
    if (status != 0) {
        throw std::runtime_error("kd_mpi_ifft execution failed.");
    }

    // Returned by value so the caller owns an independent copy.
    return m_output_buffer;
}
