#include "ai2d.h"

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

/**
 * @brief Construct an AI2D object with no pipeline built yet.
 *
 * All members keep their default initialization; call the set_* methods
 * and then build() before run().
 */
AI2D::AI2D() {}

/**
 * @brief Destroy the AI2D object.
 *
 * Members (including the owned builder) are released by their own
 * destructors; nothing extra is done here.
 */
AI2D::~AI2D() {}

/**
 * @brief Set AI2D input/output data type and format
 * @param input_format Input data format
 * @param output_format Output data format
 * @param input_dtype Input data type
 * @param output_dtype Output data type
 * @return
 * @throws std::runtime_error Thrown when setting fails
 */
void AI2D::set_ai2d_dtype(ai2d_format input_format,
                          ai2d_format output_format,
                          typecode_t input_dtype,
                          typecode_t output_dtype)
{
    try
    {
        ai2d_data_type_ = { input_format, output_format, input_dtype, output_dtype };
    }
    catch (const std::exception& e)
    {
        throw std::runtime_error(std::string("Failed to set AI2D data type: ") + e.what());
    }
}

/**
 * @brief Enable cropping and record the crop window.
 *
 * Only stores the configuration; it takes effect the next time build() is called.
 * Each argument is narrowed to int32_t, so values above INT32_MAX cannot be
 * represented faithfully.
 *
 * @param x X coordinate where the crop window starts
 * @param y Y coordinate where the crop window starts
 * @param width Width of the crop window
 * @param height Height of the crop window
 * @throws std::runtime_error Thrown when storing the parameters fails
 */
void AI2D::set_crop(size_t x, size_t y, size_t width, size_t height)
try
{
    const int32_t start_x = static_cast<int32_t>(x);
    const int32_t start_y = static_cast<int32_t>(y);
    const int32_t crop_w = static_cast<int32_t>(width);
    const int32_t crop_h = static_cast<int32_t>(height);

    // Leading 'true' turns the crop stage on.
    ai2d_crop_param_ = { true, start_x, start_y, crop_w, crop_h };
}
catch (const std::exception& e)
{
    std::string reason = "Failed to set crop parameters: ";
    reason += e.what();
    throw std::runtime_error(reason);
}

/**
 * @brief Set resize parameters
 * @param interp_method Interpolation method
 * @param interp_mode Interpolation mode
 * @return
 * @throws std::runtime_error Thrown when setting fails
 */
void AI2D::set_resize(ai2d_interp_method interp_method, ai2d_interp_mode interp_mode)
{
    try
    {
        ai2d_resize_param_ = { true, interp_method, interp_mode };
    }
    catch (const std::exception& e)
    {
        throw std::runtime_error(std::string("Failed to set resize parameters: ") + e.what());
    }
}

/**
 * @brief Enable padding and record the padding configuration.
 *
 * Only stores the configuration; it takes effect the next time build() is called.
 *
 * @param pad Padding amounts, read as four consecutive pairs:
 *            (pad[0], pad[1]), (pad[2], pad[3]), (pad[4], pad[5]), (pad[6], pad[7]).
 *            Must hold at least 8 elements; any elements past index 7 are ignored.
 * @param pad_mode Padding mode
 * @param pad_value Padding value(s), copied as-is into the stored parameters
 * @throws std::invalid_argument Thrown when pad has fewer than 8 elements
 * @throws std::runtime_error Thrown when storing the parameters fails
 */
void AI2D::set_pad(const vector<int>& pad,
                   ai2d_pad_mode pad_mode,
                   const vector<int>& pad_value)
{
    constexpr size_t kRequiredPadEntries = 8;

    // Validate before the try block so this error reaches the caller as
    // std::invalid_argument instead of being re-wrapped as runtime_error.
    if (pad.size() < kRequiredPadEntries)
    {
        throw std::invalid_argument("Pad vector must contain at least 8 elements.");
    }

    try
    {
        const int* const p = pad.data();
        ai2d_pad_param_ = { true,
                            { { p[0], p[1] }, { p[2], p[3] }, { p[4], p[5] }, { p[6], p[7] } },
                            pad_mode,
                            pad_value };
    }
    catch (const std::exception& e)
    {
        std::string reason = "Failed to set pad parameters: ";
        reason += e.what();
        throw std::runtime_error(reason);
    }
}

/**
 * @brief Enable the affine stage and record its parameters.
 *
 * Only stores the configuration; it takes effect the next time build() is called.
 * cord_round, bound_ind and bound_smooth are converted to uint32_t, so
 * negative values wrap around.
 *
 * @param interp_method Interpolation method
 * @param cord_round Coordinate rounding method
 * @param bound_ind Boundary index type
 * @param bound_val Boundary value (converted to int32_t)
 * @param bound_smooth Boundary smoothing parameter
 * @param affine_matrix Affine coefficients; only the first 6 elements are used
 * @throws std::invalid_argument Thrown when affine_matrix has fewer than 6 elements
 * @throws std::runtime_error Thrown when storing the parameters fails
 */
void AI2D::set_affine(ai2d_interp_method interp_method,
                      int cord_round,
                      int bound_ind,
                      int bound_val,
                      int bound_smooth,
                      const vector<float>& affine_matrix)
{
    constexpr size_t kAffineCoefficients = 6;

    // Validate outside the try block so the caller sees std::invalid_argument.
    if (affine_matrix.size() < kAffineCoefficients)
    {
        throw std::invalid_argument("Affine matrix must contain at least 6 elements.");
    }

    try
    {
        const float* const m = affine_matrix.data();
        ai2d_affine_param_ = { true,
                               interp_method,
                               static_cast<uint32_t>(cord_round),
                               static_cast<uint32_t>(bound_ind),
                               static_cast<int32_t>(bound_val),
                               static_cast<uint32_t>(bound_smooth),
                               { m[0], m[1], m[2], m[3], m[4], m[5] } };
    }
    catch (const std::exception& e)
    {
        std::string reason = "Failed to set affine parameters: ";
        reason += e.what();
        throw std::runtime_error(reason);
    }
}

/**
 * @brief Enable the shift stage with the given shift value.
 *
 * Only stores the configuration; it takes effect the next time build() is called.
 *
 * @param shift_val Shift value (stored as int32_t)
 * @throws std::runtime_error Thrown when storing the parameters fails
 */
void AI2D::set_shift(int shift_val)
try
{
    const int32_t amount = static_cast<int32_t>(shift_val);
    // Leading 'true' turns the shift stage on.
    ai2d_shift_param_ = { true, amount };
}
catch (const std::exception& e)
{
    std::string reason = "Failed to set shift parameters: ";
    reason += e.what();
    throw std::runtime_error(reason);
}

/**
 * @brief Build the AI2D pipeline from the currently stored parameters.
 *
 * Creates an ai2d_builder from the stored data type, crop, shift, pad, resize
 * and affine parameters and schedules it. The new builder is only made
 * available to run() once build_schedule() succeeds. After a failed build,
 * run() reports "not initialized" rather than executing a stale or
 * unscheduled builder.
 *
 * @param input_shape Input tensor shape (copied into the object)
 * @param output_shape Output tensor shape (copied into the object)
 * @throws std::runtime_error Thrown when the builder cannot be created or scheduling fails
 */
void AI2D::build(const dims_t& input_shape, const dims_t& output_shape)
{
    input_shape_ = input_shape;
    output_shape_ = output_shape;

    // Drop any previous pipeline first; it was configured for the old shapes.
    ai2d_builder_.reset();

    try
    {
        // The builder receives the member copies of the shapes, not the
        // caller's const references.
        auto builder = std::make_unique<ai2d_builder>(input_shape_,
                                                      output_shape_,
                                                      ai2d_data_type_,
                                                      ai2d_crop_param_,
                                                      ai2d_shift_param_,
                                                      ai2d_pad_param_,
                                                      ai2d_resize_param_,
                                                      ai2d_affine_param_);

        // build_schedule() returns an nncase result object. An error in it is
        // not raised as an exception, so it must be checked explicitly.
        if (builder->build_schedule().is_err())
            throw std::runtime_error("build_schedule() reported an error");

        ai2d_builder_ = std::move(builder);
    }
    catch (const std::exception& e)
    {
        throw std::runtime_error(std::string("Failed to build AI2D pipeline: ") + e.what());
    }
}

/**
 * @brief Run the built AI2D pipeline on one input tensor.
 * @param input_tensor Input tensor
 * @param output_tensor Output tensor that receives the result
 * @throws std::runtime_error Thrown when build() has not produced a pipeline
 *         or when the pipeline reports an execution error
 */
void AI2D::run(runtime_tensor& input_tensor, runtime_tensor& output_tensor)
{
    if (!ai2d_builder_)
        throw std::runtime_error("AI2D builder is not initialized. Call build() first.");

    try
    {
        // In nncase, result::expect() is noexcept and calls fail_fast() on
        // error, so it never reaches this catch. Inspect the result and
        // throw instead, which honours the documented contract.
        if (ai2d_builder_->invoke(input_tensor, output_tensor).is_err())
            throw std::runtime_error("error occurred in ai2d running");
    }
    catch (const std::exception& e)
    {
        throw std::runtime_error(std::string("Failed to run AI2D pipeline: ") + e.what());
    }
}
