#include "kpu.h"

#include <fstream>
#include <iostream>
#include <stdexcept>
#include <utility>

using nncase::runtime::interpreter;
using nncase::runtime::runtime_tensor;

/**
 * @brief Constructor, load kmodel file and initialize input/output info
 *
 * Loads the model into the interpreter, binds a newly created host tensor
 * (allocated from hrt::pool_shared) to every input, and caches the shape,
 * typecode, byte size and element count of every input and output.
 * @param kmodel_file Path to kmodel file
 * @throws std::runtime_error If the file cannot be opened, the model fails to
 *         load, the file fails to close, or an input/output uses a data type
 *         whose element width is not handled here
 */
KPU::KPU(const std::string& kmodel_file)
    : kmodel_file_(kmodel_file)
{
    std::ifstream ifs(kmodel_file_, std::ios::binary);
    if (!ifs.is_open())
    {
        throw std::runtime_error("Failed to open kmodel file: " + kmodel_file_);
    }

    // Load model
    auto res = kmodel_interp_.load_model(ifs);
    if (!res.is_ok())
    {
        throw std::runtime_error("Failed to load kmodel: " + kmodel_file_);
    }

    ifs.close();
    if (ifs.fail())
    {
        throw std::runtime_error("Failed to close kmodel file: " + kmodel_file_);
    }

    input_size_  = kmodel_interp_.inputs_size();
    output_size_ = kmodel_interp_.outputs_size();

    // Width in bytes of one element of the given typecode, or 0 for a type this
    // wrapper does not know how to size. Shared by the input and output loops so
    // the two tables cannot drift apart.
    const auto element_bytes = [](auto datatype) -> size_t
    {
        if (datatype == dt_int8 || datatype == dt_uint8 || datatype == dt_boolean)
            return 1;
        if (datatype == dt_int16 || datatype == dt_uint16 || datatype == dt_float16 || datatype == dt_bfloat16)
            return 2;
        if (datatype == dt_int32 || datatype == dt_uint32 || datatype == dt_float32)
            return 4;
        if (datatype == dt_int64 || datatype == dt_uint64 || datatype == dt_float64)
            return 8;
        return 0;
    };

    // Initialize input info and bind one host tensor per input
    input_shapes_.resize(input_size_);
    input_typecodes_.resize(input_size_);
    input_data_size_.resize(input_size_);
    input_data_bytes_.resize(input_size_);
    for (int i = 0; i < input_size_; ++i)
    {
        auto desc = kmodel_interp_.input_desc(i);
        // Reject unknown types before allocating a tensor for them.
        const size_t dsize = element_bytes(desc.datatype);
        if (dsize == 0)
        {
            throw std::runtime_error("unsupported kmodel input data type");
        }
        auto shape_i = kmodel_interp_.input_shape(i);
        auto tensor = host_runtime_tensor::create(desc.datatype, shape_i, hrt::pool_shared).expect("cannot create input tensor");
        kmodel_interp_.input_tensor(i, tensor).expect("cannot set input tensor");
        input_shapes_[i] = shape_i;
        input_typecodes_[i] = desc.datatype;
        input_data_bytes_[i] = desc.size;
        input_data_size_[i] = desc.size / dsize;
    }

    // Initialize output info; output data is copied out later by run()
    output_shapes_.resize(output_size_);
    output_typecodes_.resize(output_size_);
    output_data_size_.resize(output_size_);
    output_data_bytes_.resize(output_size_);
    output_data_.resize(output_size_);
    for (int i = 0; i < output_size_; ++i)
    {
        auto desc = kmodel_interp_.output_desc(i);
        const size_t dsize = element_bytes(desc.datatype);
        if (dsize == 0)
        {
            throw std::runtime_error("unsupported kmodel output data type");
        }
        output_shapes_[i] = kmodel_interp_.output_shape(i);
        output_typecodes_[i] = desc.datatype;
        output_data_bytes_[i] = desc.size;
        output_data_size_[i] = desc.size / dsize;
    }
}

/**
 * @brief Destroy the KPU wrapper.
 *
 * No explicit cleanup is performed here; every data member is released by
 * its own destructor.
 */
KPU::~KPU()
{
}

/**
 * @brief Get number of model inputs
 * @return int Number of input tensors
 */
int KPU::get_input_size() noexcept
{
    return input_size_;
}

/**
 * @brief Get number of model outputs
 * @return int Number of output tensors
 */
int KPU::get_output_size() noexcept
{
    return output_size_;
}

/**
 * @brief Look up the dimensions of one model input.
 * @param idx Zero-based input index
 * @return dims_t Shape recorded for that input when the model was loaded
 * @throws std::out_of_range If idx is negative or not below the input count
 */
dims_t KPU::get_input_shape(int idx)
{
    if (idx >= 0 && idx < input_size_)
    {
        return input_shapes_[idx];
    }
    throw std::out_of_range("Invalid input index");
}

/**
 * @brief Look up the dimensions of one model output.
 * @param idx Zero-based output index
 * @return dims_t Shape recorded for that output when the model was loaded
 * @throws std::out_of_range If idx is negative or not below the output count
 */
dims_t KPU::get_output_shape(int idx)
{
    if (idx >= 0 && idx < output_size_)
    {
        return output_shapes_[idx];
    }
    throw std::out_of_range("Invalid output index");
}

/**
 * @brief Look up the element type of one model input.
 * @param idx Zero-based input index
 * @return typecode_t Data type taken from the input descriptor at load time
 * @throws std::out_of_range If idx is negative or not below the input count
 */
typecode_t KPU::get_input_typecode(int idx)
{
    const bool in_range = idx >= 0 && idx < input_size_;
    if (!in_range)
    {
        throw std::out_of_range("Invalid input index");
    }
    return input_typecodes_[idx];
}

/**
 * @brief Look up the element type of one model output.
 * @param idx Zero-based output index
 * @return typecode_t Data type taken from the output descriptor at load time
 * @throws std::out_of_range If idx is negative or not below the output count
 */
typecode_t KPU::get_output_typecode(int idx)
{
    const bool in_range = idx >= 0 && idx < output_size_;
    if (!in_range)
    {
        throw std::out_of_range("Invalid output index");
    }
    return output_typecodes_[idx];
}

/**
 * @brief Number of elements in one model input.
 * @param idx Zero-based input index
 * @return size_t Descriptor byte size divided by the element width of the input's type
 * @throws std::out_of_range If idx is negative or not below the input count
 */
size_t KPU::get_input_data_size(int idx)
{
    if (idx >= 0 && idx < input_size_)
    {
        return input_data_size_[idx];
    }
    throw std::out_of_range("Invalid input index");
}

/**
 * @brief Number of elements in one model output.
 * @param idx Zero-based output index
 * @return size_t Descriptor byte size divided by the element width of the output's type
 * @throws std::out_of_range If idx is negative or not below the output count
 */
size_t KPU::get_output_data_size(int idx)
{
    if (idx >= 0 && idx < output_size_)
    {
        return output_data_size_[idx];
    }
    throw std::out_of_range("Invalid output index");
}

/**
 * @brief Size in bytes of one model input.
 * @param idx Zero-based input index
 * @return size_t Byte size reported by the input descriptor at load time
 * @throws std::out_of_range If idx is negative or not below the input count
 */
size_t KPU::get_input_data_bytes(int idx)
{
    const bool in_range = idx >= 0 && idx < input_size_;
    if (!in_range)
    {
        throw std::out_of_range("Invalid input index");
    }
    return input_data_bytes_[idx];
}

/**
 * @brief Size in bytes of one model output.
 * @param idx Zero-based output index
 * @return size_t Byte size reported by the output descriptor at load time
 * @throws std::out_of_range If idx is negative or not below the output count
 */
size_t KPU::get_output_data_bytes(int idx)
{
    const bool in_range = idx >= 0 && idx < output_size_;
    if (!in_range)
    {
        throw std::out_of_range("Invalid output index");
    }
    return output_data_bytes_[idx];
}

/**
 * @brief Bind a caller-supplied tensor to one model input.
 * @param idx Zero-based input index
 * @param input_tensor Tensor handed to the interpreter for that input
 * @throws std::out_of_range If idx is negative or not below the input count
 * @throws std::runtime_error If the interpreter rejects the tensor
 */
void KPU::set_input_tensor(int idx, const runtime_tensor& input_tensor)
{
    const bool valid_index = idx >= 0 && idx < input_size_;
    if (!valid_index)
    {
        throw std::out_of_range("Invalid input index");
    }
    const bool accepted = kmodel_interp_.input_tensor(idx, input_tensor).is_ok();
    if (!accepted)
    {
        throw std::runtime_error("Failed to set input tensor");
    }
}

/**
 * @brief Bind every model input in one call.
 * @param input_tensors One tensor per model input, in input-index order
 * @throws std::invalid_argument If the number of tensors differs from the model's input count
 * @throws std::runtime_error If the interpreter rejects a tensor; inputs with a
 *         lower index than the failing one have already been rebound
 */
void KPU::set_input_tensors(const std::vector<nncase::runtime::runtime_tensor>& input_tensors)
{
    // vector::size() is unsigned; cast input_size_ so the comparison is done on
    // unsigned values on both sides (no signed/unsigned mismatch).
    if (input_tensors.size() != static_cast<size_t>(input_size_))
        throw std::invalid_argument("Input tensor size mismatch");
    for (size_t i = 0; i < input_tensors.size(); ++i)
    {
        auto res = kmodel_interp_.input_tensor(i, input_tensors[i]);
        if (!res.is_ok())
            throw std::runtime_error("Failed to set input tensor");
    }
}

/**
 * @brief Get a specific input tensor object
 * @param idx Input index
 * @return runtime_tensor Input tensor object currently bound in the interpreter
 * @throws std::out_of_range Thrown if index is out of range
 * @throws std::runtime_error Thrown if the interpreter cannot return the tensor
 */
runtime_tensor KPU::get_input_tensor(int idx)
{
    if (idx < 0 || idx >= input_size_)
        throw std::out_of_range("Invalid input index");
    // Check the result explicitly instead of calling expect(), so a failure is
    // reported as the std::runtime_error documented above, the same way
    // set_input_tensor() reports interpreter errors.
    auto res = kmodel_interp_.input_tensor(idx);
    if (!res.is_ok())
        throw std::runtime_error("Failed to get input tensor");
    return std::move(res).unwrap();
}

/**
 * @brief Run kmodel inference
 * @throws std::runtime_error Thrown if run fails or output tensor get fails
 */
void KPU::run()
{
    auto res = kmodel_interp_.run();
    if (!res.is_ok())
        throw std::runtime_error("Failed to run kmodel");
    output_data_.assign(output_size_, std::vector<char>());
    // Map output tensor pointers
    for (int i = 0; i < output_size_; ++i)
    {
        auto res = kmodel_interp_.output_tensor(i).expect("cannot get output tensor");
        auto output_buf = res.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_read).unwrap().buffer();
        // Copy output tensor data
        output_data_[i].resize(output_data_bytes_[i]);
        memcpy(output_data_[i].data(), output_buf.data(), output_data_bytes_[i]);
    }
}

/**
 * @brief Get a specific output tensor object
 * @param idx Output index
 * @return runtime_tensor Output tensor object held by the interpreter
 * @throws std::out_of_range Thrown if index is out of range
 * @throws std::runtime_error Thrown if the interpreter cannot return the tensor
 */
runtime_tensor KPU::get_output_tensor(int idx)
{
    if (idx < 0 || idx >= output_size_)
        throw std::out_of_range("Invalid output index");
    // Check the result explicitly instead of calling expect(), so a failure is
    // reported as the std::runtime_error documented above.
    auto res = kmodel_interp_.output_tensor(idx);
    if (!res.is_ok())
        throw std::runtime_error("Failed to get output tensor");
    return std::move(res).unwrap();
}

/**
 * @brief Pointer to the cached copy of one output's data.
 * @param idx Zero-based output index
 * @return char* First byte of the copy stored for that output by run()
 * @throws std::out_of_range If idx is negative or not below the output count
 * @throws std::runtime_error If no data is cached for that output (run() has
 *         not filled it)
 */
char* KPU::get_output_ptr(int idx)
{
    if (idx >= 0 && idx < output_size_)
    {
        auto& cached = output_data_[idx];
        if (cached.empty())
        {
            throw std::runtime_error("Output tensor not mapped");
        }
        return cached.data();
    }
    throw std::out_of_range("Invalid output index");
}

/**
 * @brief Copy of the cached data of one output.
 * @param idx Zero-based output index
 * @return std::vector<char> A copy of the bytes stored for that output by run()
 * @throws std::out_of_range If idx is negative or not below the output count
 * @throws std::runtime_error If no data is cached for that output (run() has
 *         not filled it)
 */
std::vector<char> KPU::get_output_data(int idx)
{
    if (idx >= 0 && idx < output_size_)
    {
        const auto& cached = output_data_[idx];
        if (cached.empty())
        {
            throw std::runtime_error("Output tensor not mapped");
        }
        return cached;
    }
    throw std::out_of_range("Invalid output index");
}
