#ifndef KPU_H
#define KPU_H

#include <vector>
#include <string>
#include <fstream>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
#include <nncase/runtime/simple_types.h>
#include "ai2d.h"

/**
 * @brief KPU (Knowledge Processing Unit) wrapper class
 *
 * Encapsulates nncase kmodel inference related interfaces, providing
 * input/output tensor management, inference execution, and output data access.
 *
 * The class owns a single nncase interpreter bound to the kmodel passed to the
 * constructor. It is therefore move-only: copying is explicitly deleted, so
 * two KPU objects never silently share or duplicate one interpreter.
 */
class KPU
{
public:
    /**
     * @brief Constructor
     * @param kmodel_file Path to the kmodel file
     */
    explicit KPU(const std::string& kmodel_file);

    /**
     * @brief Destructor
     */
    ~KPU();

    // Disable copy constructor and copy assignment.
    // These were previously "= default", which contradicted this comment and
    // would have enabled memberwise copies whenever the interpreter is copyable.
    KPU(const KPU&) = delete;
    KPU& operator=(const KPU&) = delete;

    // Enable move constructor and move assignment (ownership transfer)
    KPU(KPU&&) noexcept = default;
    KPU& operator=(KPU&&) noexcept = default;

    /**
     * @brief Get the number of model inputs
     * @return int Number of inputs
     */
    int get_input_size() noexcept;

    /**
     * @brief Get the number of model outputs
     * @return int Number of outputs
     */
    int get_output_size() noexcept;

    /**
     * @brief Get the shape of a specific input tensor
     * @param idx Input index
     * @return dims_t Dimensions of the input tensor
     */
    dims_t get_input_shape(int idx);

    /**
     * @brief Get the shape of a specific output tensor
     * @param idx Output index
     * @return dims_t Dimensions of the output tensor
     */
    dims_t get_output_shape(int idx);

    /**
     * @brief Get the type of a specific input tensor
     * @param idx Input index
     * @return typecode_t Type of the input tensor
     */
    typecode_t get_input_typecode(int idx);

    /**
     * @brief Get the type of a specific output tensor
     * @param idx Output index
     * @return typecode_t Type of the output tensor
     */
    typecode_t get_output_typecode(int idx);

    /**
     * @brief Get the data size of a specific input tensor
     * @param idx Input index
     * @return size_t Data size of the input tensor
     * @note NOTE(review): presumably an element count, as distinct from
     *       get_input_data_bytes(); confirm against the implementation.
     */
    size_t get_input_data_size(int idx);

    /**
     * @brief Get the data size of a specific output tensor
     * @param idx Output index
     * @return size_t Data size of the output tensor
     * @note NOTE(review): presumably an element count, as distinct from
     *       get_output_data_bytes(); confirm against the implementation.
     */
    size_t get_output_data_size(int idx);

    /**
     * @brief Get the number of bytes of a specific input tensor
     * @param idx Input index
     * @return size_t Byte size of the input tensor
     */
    size_t get_input_data_bytes(int idx);

    /**
     * @brief Get the number of bytes of a specific output tensor
     * @param idx Output index
     * @return size_t Byte size of the output tensor
     */
    size_t get_output_data_bytes(int idx);

    /**
     * @brief Set the data for a specific input tensor
     * @param idx Input index
     * @param input_tensor Input tensor object
     * @throws std::out_of_range Throws if index is out of range
     * @throws std::runtime_error Throws if setting fails
     */
    void set_input_tensor(int idx, const nncase::runtime::runtime_tensor& input_tensor);

    /**
     * @brief Set multiple input tensors
     * @param input_tensors Array of input tensor objects, one per model input
     * @throws std::invalid_argument Throws if number of tensors does not match model input count
     * @throws std::runtime_error Throws if setting fails
     */
    void set_input_tensors(const std::vector<nncase::runtime::runtime_tensor>& input_tensors);

    /**
     * @brief Get a specific input tensor object
     * @param idx Input index
     * @return nncase::runtime::runtime_tensor Input tensor object
     */
    nncase::runtime::runtime_tensor get_input_tensor(int idx);

    /**
     * @brief Run inference
     * @throws std::runtime_error Throws if inference fails or output tensor retrieval fails
     */
    void run();

    /**
     * @brief Get a specific output tensor object
     * @param idx Output index
     * @return nncase::runtime::runtime_tensor Output tensor object
     */
    nncase::runtime::runtime_tensor get_output_tensor(int idx);

    /**
     * @brief Get the raw data pointer of a specific output tensor
     * @param idx Output index
     * @return char* Pointer to the output tensor data
     * @note NOTE(review): the pointer's lifetime is not visible from this
     *       header; presumably it is invalidated by the next run() or by
     *       destroying/moving this object. Confirm before caching it.
     */
    char* get_output_ptr(int idx);

    /**
     * @brief Get the raw data of a specific output tensor
     * @param idx Output index
     * @return std::vector<char> Copy of the output tensor data
     */
    std::vector<char> get_output_data(int idx);

private:
    std::string kmodel_file_;                             ///< Path to the kmodel file
    nncase::runtime::interpreter kmodel_interp_;          ///< kmodel interpreter object (owned; reason KPU is move-only)

    int input_size_{0};                                   ///< Number of input tensors
    int output_size_{0};                                  ///< Number of output tensors
    std::vector<dims_t> input_shapes_;                    ///< Shapes of input tensors, indexed by input idx
    std::vector<dims_t> output_shapes_;                   ///< Shapes of output tensors, indexed by output idx
    std::vector<typecode_t> input_typecodes_;             ///< Types of input tensors
    std::vector<typecode_t> output_typecodes_;            ///< Types of output tensors
    std::vector<size_t> input_data_size_;                 ///< Data sizes of input tensors
    std::vector<size_t> output_data_size_;                ///< Data sizes of output tensors
    std::vector<size_t> input_data_bytes_;                ///< Byte sizes of input tensors
    std::vector<size_t> output_data_bytes_;               ///< Byte sizes of output tensors
    std::vector<std::vector<char>> output_data_;          ///< Output tensor data buffers
};

#endif // KPU_H
