#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <kpu.h>

/**
 * @file test_kpu.ino
 * @brief Arduino/KPU test program: loads a model, runs inference on input.bin, 
 *        compares output with output.bin using cosine similarity.
 */

// Locations of the kmodel under test, the raw input data fed to it,
// and the reference output data its result is compared against.
std::string kmodel_path{"test.kmodel"};
std::string input_bin{"input.bin"};
std::string output_bin{"output.bin"};

/**
 * @brief Compute cosine similarity between two float vectors.
 *
 * Every element is widened to double before it is multiplied, so the
 * products cannot overflow or underflow the float range; only the final
 * quotient is narrowed back to float.
 *
 * @param a Pointer to first vector (at least @p length readable elements)
 * @param b Pointer to second vector (at least @p length readable elements)
 * @param length Number of elements in each vector
 * @return Cosine similarity value in range [-1.0, 1.0], 0 if any vector is
 *         zero (which includes the case length == 0)
 */
float cosine_similarity(const float* a, const float* b, size_t length) {
    double dot = 0.0;
    double norm_a = 0.0;
    double norm_b = 0.0;

    for (size_t i = 0; i < length; ++i) {
        // Promote first: float * float would be rounded (or saturate to
        // inf / flush to 0) before reaching the double accumulators.
        const double x = a[i];
        const double y = b[i];
        dot    += x * y;           /**< Dot product */
        norm_a += x * x;           /**< Squared norm of vector a */
        norm_b += y * y;           /**< Squared norm of vector b */
    }

    if (norm_a == 0.0 || norm_b == 0.0) {
        return 0.0f;               /**< Avoid division by zero */
    }

    return static_cast<float>(dot / (std::sqrt(norm_a) * std::sqrt(norm_b)));
}

/**
 * @brief Arduino setup function: runs once at startup
 */
void setup() {
    /** Load input binary file into memory */
    std::ifstream ifs_input(input_bin, std::ios::binary);
    ifs_input.seekg(0, ifs_input.end);
    size_t len_input = ifs_input.tellg();
    char* input_bin_data = new char[len_input];
    ifs_input.seekg(0, ifs_input.beg);
    ifs_input.read(input_bin_data, len_input);
    ifs_input.close();

    /** Load expected output binary file into memory */
    std::ifstream ifs_output(output_bin, std::ios::binary);
    ifs_output.seekg(0, ifs_output.end);
    size_t len_output = ifs_output.tellg();
    size_t num_floats = len_output / sizeof(float);
    float* output_bin_data = new float[num_floats];
    ifs_output.seekg(0, ifs_output.beg);
    ifs_output.read(reinterpret_cast<char*>(output_bin_data), len_output);
    ifs_output.close();

    /** Initialize Kmodel */
    KPU kpu(kmodel_path);
    int input_size  = kpu.get_input_size();   /**< Number of input tensors */
    int output_size = kpu.get_output_size();  /**< Number of output tensors */
    printf("input size: %d\n", input_size);
    printf("output size: %d\n", output_size);

    /** Print input tensor shapes */
    for (int i = 0; i < input_size; i++) {
        printf("input %d shape: ", i);
        dims_t shape_i = kpu.get_input_shape(i);
        for (size_t j = 0; j < shape_i.size(); j++) {
            printf("%zu ", shape_i[j]);
        }
        printf("\n");
    }

    /** Print output tensor shapes */
    for (int i = 0; i < output_size; i++) {
        printf("output %d shape: ", i);
        dims_t shape_i = kpu.get_output_shape(i);
        for (size_t j = 0; j < shape_i.size(); j++) {
            printf("%zu ", shape_i[j]);
        }
        printf("\n");
    }

    /** Get input type and size information */
    typecode_t input_typecode = kpu.get_input_typecode(0);
    printf("input typecode:%d\n", input_typecode);
    dims_t input_shape = kpu.get_input_shape(0);
    size_t input_data_size = kpu.get_input_data_size(0);
    printf("input data size: %d\n", input_data_size);
    size_t input_data_bytes = kpu.get_input_data_bytes(0);
    printf("input data bytes: %d\n", input_data_bytes);

    /** Create runtime tensor and copy input data */
    runtime_tensor input_tensor = host_runtime_tensor::create(
        input_typecode,
        input_shape,
        { (gsl::byte *)input_bin_data, input_data_bytes },
        true,
        hrt::pool_shared
    ).expect("cannot create input tensor");
    hrt::sync(input_tensor, sync_op_t::sync_write_back, true).unwrap();

    /** Set input tensor and run KPU inference */
    kpu.set_input_tensor(0, input_tensor);
    kpu.run();

    /** Get output type and size information */
    typecode_t output_typecode = kpu.get_output_typecode(0);
    printf("output typecode:%d\n", output_typecode);
    dims_t output_shape = kpu.get_output_shape(0);
    size_t output_data_size = kpu.get_output_data_size(0);
    printf("output data size: %d\n", output_data_size);
    size_t output_data_bytes = kpu.get_output_data_bytes(0);
    printf("output data bytes: %d\n", output_data_bytes);

    /** Retrieve KPU output data */
    std::vector<char> output_kmodel_data = kpu.get_output_data(0);

    /** Compute cosine similarity between KPU output and expected output */
    float sim = cosine_similarity(
        reinterpret_cast<float *>(output_kmodel_data.data()),
        output_bin_data,
        kpu.get_output_data_size(0)
    );
    std::cout << "cosine similarity: " << sim << std::endl;
}

/**
 * @brief Arduino loop hook, invoked repeatedly once setup() has returned.
 *
 * Intentionally a no-op: this test program has no recurring work.
 */
void loop() {}
