#include <iostream>
#include "utils.h"

using std::ofstream;
using std::vector;

/**
 * @brief Get color list for classes by cycling through predefined color palette.
 * 
 * @param num_classes Number of object classes.
 * @return std::vector<cv::Scalar> Color list corresponding to each class.
 */
std::vector<cv::Scalar> get_colors_for_classes(int num_classes) {
    std::vector<cv::Scalar> colors;
    int num_available_colors = color_four.size(); 
    for (int i = 0; i < num_classes; ++i) {
        colors.push_back(color_four[i % num_available_colors]);
    }
    return colors;
}

/**
 * @brief Read label names from a text file (one label per line).
 *
 * A trailing '\r' is stripped from every line so that CRLF files produce the
 * same labels as LF files. Lines that are empty after stripping are skipped.
 * If the file cannot be opened, an error is printed to std::cerr and an empty
 * vector is returned.
 *
 * @param labels_txt_path Path to the label text file.
 * @return std::vector<std::string> Labels in file order, without blank entries.
 */
std::vector<std::string> read_labels_from_txt(std::string labels_txt_path) {
    std::vector<std::string> labels;
    std::ifstream file(labels_txt_path);
    if (!file.is_open()) {
        std::cerr << "Error: Could not open file " << labels_txt_path << std::endl;
        return labels;
    }

    std::string line;
    while (std::getline(file, line)) {
        // Strip the carriage return BEFORE the emptiness check: a blank line in
        // a CRLF file arrives here as "\r" and must not become an empty label.
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        if (line.empty()) {
            continue;
        }
        labels.push_back(line);
    }
    return labels;
}

/**
 * @brief Convert a cv::Mat image to a uint8 runtime_tensor in NCHW or NHWC layout.
 *
 * - "CHW": the image is split into planes and the planes are written in
 *   REVERSE channel order (last channel first), giving shape {1, C, H, W}.
 *   NOTE(review): presumably this turns OpenCV's BGR into RGB — confirm with callers.
 * - "HWC": the interleaved pixel bytes are copied as-is (channel order
 *   unchanged), giving shape {1, H, W, C}.
 *
 * Both paths copy exactly one byte per channel sample, i.e. they assume an
 * 8-bit image.
 *
 * @param ori_img Input image as a cv::Mat.
 * @param mode Conversion mode: "CHW" for channel-first format, "HWC" for channel-last format.
 * @return runtime_tensor Corresponding tensor with image data.
 * @throws std::runtime_error If @p mode is neither "CHW" nor "HWC", or if any
 *         std::exception is raised during conversion (the message is wrapped).
 */
runtime_tensor mat_to_tensor(cv::Mat &ori_img, std::string mode) {
    runtime_tensor image_tensor;
    try {
        const size_t H = ori_img.rows;
        const size_t W = ori_img.cols;
        const size_t C = ori_img.channels();
        const size_t total_bytes = C * H * W;

        if (mode == "CHW") {
            std::vector<uint8_t> chw_vec(total_bytes);
            std::vector<cv::Mat> channels(C);
            cv::split(ori_img, channels);

            uint8_t *dst = chw_vec.data();
            for (size_t c = 0; c < C; ++c) {
                // Planes are emitted last-to-first to reverse the channel order.
                const cv::Mat &plane = channels[C - 1 - c];
                // Copy row by row so that planes with padded rows are handled.
                for (int row = 0; row < ori_img.rows; ++row) {
                    memcpy(dst, plane.ptr<uint8_t>(row), W);
                    dst += W;
                }
            }

            dims_t image_shape {1, C, H, W};
            image_tensor = host_runtime_tensor::create(
                typecode_t::dt_uint8,
                image_shape,
                { reinterpret_cast<gsl::byte*>(chw_vec.data()), total_bytes },
                true,
                hrt::pool_shared
            ).expect("cannot create input tensor");
        } else if (mode == "HWC") {
            // A Mat that views a ROI (or has padded rows) is not one contiguous
            // H*W*C byte block; reading `data` directly would pick up bytes from
            // outside the image. clone() produces a densely packed copy.
            // Like chw_vec above, this temporary dies before hrt::sync, so it
            // relies on create() taking its own copy of the buffer.
            cv::Mat packed = ori_img.isContinuous() ? ori_img : ori_img.clone();

            dims_t image_shape {1, H, W, C};
            image_tensor = host_runtime_tensor::create(
                typecode_t::dt_uint8,
                image_shape,
                { reinterpret_cast<gsl::byte*>(packed.data), total_bytes },
                true,
                hrt::pool_shared
            ).expect("cannot create input tensor");
        }
        else {
            throw std::runtime_error("Error: Invalid mode. Use 'CHW' or 'HWC'.");
        }
        hrt::sync(image_tensor, sync_op_t::sync_write_back, true).expect("write back input failed");
    } catch (const std::exception &e) {
        throw std::runtime_error(std::string("Error: Failed to convert image to tensor: ") + e.what());
    }
    return image_tensor;
}


/**
 * @brief Convert a k_video_frame_info to a uint8 runtime_tensor.
 *
 * The tensor layout follows the frame's pixel format:
 * - PIXEL_FORMAT_RGB_888_PLANAR -> shape {1, 3, height, width}
 * - PIXEL_FORMAT_RGB_888        -> shape {1, height, width, 3}
 *
 * The frame's first physical plane (`phys_addr[0]`) is mapped with
 * kd_mpi_sys_mmap for the duration of the call, handed to
 * host_runtime_tensor::create, synced, and then unmapped again.
 *
 * @param videoframe Input video frame information.
 * @return runtime_tensor Corresponding tensor with image data.
 * @throws std::runtime_error If the pixel format is not one of the two above,
 *         or if the frame memory cannot be mapped.
 */
runtime_tensor videoframe_to_tensor(k_video_frame_info &videoframe) {
    const auto pixel_format = videoframe.v_frame.pixel_format;
    const bool is_planar = (pixel_format == PIXEL_FORMAT_RGB_888_PLANAR);
    if (!is_planar && pixel_format != PIXEL_FORMAT_RGB_888) {
        throw std::runtime_error("Error: Invalid pixel format.");
    }

    const size_t height = videoframe.v_frame.height;
    const size_t width = videoframe.v_frame.width;
    // Three 8-bit channels in both supported formats.
    const size_t byte_count = 3 * height * width;

    dims_t input_shape = is_planar ? dims_t{1, 3, height, width}
                                   : dims_t{1, height, width, 3};

    uintptr_t virt_addr = reinterpret_cast<uintptr_t>(kd_mpi_sys_mmap(videoframe.v_frame.phys_addr[0], byte_count));
    // Do not hand a null mapping to the tensor constructor.
    if (virt_addr == 0) {
        throw std::runtime_error("Error: Failed to map video frame memory.");
    }

    runtime_tensor image_tensor = host_runtime_tensor::create(
        typecode_t::dt_uint8,
        input_shape,
        { reinterpret_cast<gsl::byte*>(virt_addr), byte_count },
        true,
        hrt::pool_shared
    ).expect("cannot create input tensor");
    hrt::sync(image_tensor, sync_op_t::sync_write_back, true).expect("write back input failed");
    kd_mpi_sys_munmap(reinterpret_cast<void*>(virt_addr), byte_count);

    return image_tensor;
}


/**
 * @brief Perform Non-Maximum Suppression (NMS) on axis-aligned bounding boxes.
 * 
 * @param bboxes Input and output bounding box list.
 * @param confThreshold Confidence threshold to filter boxes.
 * @param nmsThreshold IoU threshold for suppression.
 * @param indices Output indices of retained boxes.
 */
void nms(std::vector<YOLOBbox> &bboxes,  float confThreshold, float nmsThreshold, std::vector<int> &indices)
{	
    std::sort(bboxes.begin(), bboxes.end(), [](YOLOBbox &a, YOLOBbox &b) { return a.confidence > b.confidence; });
    int updated_size = bboxes.size();
    for (int i = 0; i < updated_size; i++) {
        if (bboxes[i].confidence < confThreshold)
            continue;
        indices.push_back(i);
        for (int j = i + 1; j < updated_size;) {
            float iou = iou_calculate(bboxes[i].box, bboxes[j].box);
            if (iou > nmsThreshold) {
                bboxes[j].confidence = -1;  // Mark invalid detections
            }
            j++;
        }
    }
    bboxes.erase(std::remove_if(bboxes.begin(), bboxes.end(), [](YOLOBbox &b) { return b.confidence < 0; }), bboxes.end());
}

/**
 * @brief Calculate Intersection over Union (IoU) between two rectangles.
 *
 * Each rectangle covers [x, x + width) by [y, y + height).
 *
 * @param rect1 First rectangle.
 * @param rect2 Second rectangle.
 * @return float IoU value; 0 when the union area is not positive (e.g. both
 *         rectangles are empty), instead of the NaN a 0/0 division would give.
 */
float iou_calculate(cv::Rect &rect1, cv::Rect &rect2)
{
    const int left = std::max(rect1.x, rect2.x);
    const int top = std::max(rect1.y, rect2.y);
    const int right = std::min(rect1.x + rect1.width, rect2.x + rect2.width);
    const int bottom = std::min(rect1.y + rect1.height, rect2.y + rect2.height);

    // Disjoint rectangles give a negative extent; clamp it to zero.
    const int overlap_width = std::max(0, right - left);
    const int overlap_height = std::max(0, bottom - top);

    // Areas are accumulated in float so large boxes cannot overflow int.
    const float overlap_area = static_cast<float>(overlap_width) * overlap_height;
    const float area1 = static_cast<float>(rect1.width) * rect1.height;
    const float area2 = static_cast<float>(rect2.width) * rect2.height;
    const float union_area = area1 + area2 - overlap_area;

    if (union_area <= 0.0f) {
        return 0.0f;
    }
    return overlap_area / union_area;
}

/**
 * @brief Fast exponential approximation.
 *
 * Builds the IEEE-754 bit pattern of the result directly: the scaled input
 * `2^23 * (x * log2(e) + bias)` is written into the exponent/mantissa bits of
 * a float. The result is only a rough approximation of exp(x) (a few percent
 * relative error).
 *
 * Inputs whose scaled value would fall outside the representable bit range
 * saturate: very negative x returns 0, very large x returns +infinity. NaN is
 * returned unchanged.
 *
 * @param x Input value.
 * @return float Approximated exponential of x.
 */
float fast_exp(float x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "fast_exp requires a 32-bit float");

    if (std::isnan(x)) {
        return x;
    }

    const double scaled = (1 << 23) * (1.4426950409 * x + 126.93490512f);

    uint32_t bits;
    if (scaled <= 0.0) {
        // Converting a negative double to an unsigned integer is undefined.
        bits = 0u;
    } else if (scaled >= 2139095040.0) {
        // 0x7F800000 is +infinity; anything larger would produce NaN bit
        // patterns or spill into the sign bit and yield a negative "exp".
        bits = 0x7F800000u;
    } else {
        bits = static_cast<uint32_t>(scaled);
    }

    // memcpy is the well-defined way to reinterpret the bits as a float.
    float result;
    memcpy(&result, &bits, sizeof(result));
    return result;
}

/**
 * @brief Logistic sigmoid, 1 / (1 + e^-x), evaluated with fast_exp().
 *
 * @param x Input value.
 * @return float Sigmoid of x, with the accuracy of the fast_exp() approximation.
 */
float sigmoid(float x)
{
    const float decay = fast_exp(-x);
    const float denominator = 1.0f + decay;
    return 1.0f / denominator;
}

/**
 * @brief Restrict a value to the closed range [low, high].
 *
 * @tparam T Value type; must support operator< and operator>.
 * @param value Input value.
 * @param low Lower bound, returned when value is below it.
 * @param high Upper bound, returned when value is above it.
 * @return T low, high, or value itself when it already lies within the range.
 */
template<typename T>
T def_clamp(T value, T low, T high) {
    if (value < low) {
        return low;
    }
    if (value > high) {
        return high;
    }
    return value;
}

/**
 * @brief Compute covariance matrix parameters for an oriented bounding box (OBB).
 *
 * The box is modelled as a 2-D Gaussian whose axis-aligned variances are the
 * squared half extents, (width/2)^2 and (height/2)^2, rotated by `obb.angle`:
 *
 *     | a  c |
 *     | c  b |  =  R(angle) * diag(var_w, var_h) * R(angle)^T
 *
 * NOTE(review): the Ultralytics reference uses width^2/12 and height^2/12 as
 * the variances; the half-extent scale used here is kept unchanged — confirm
 * which scale the NMS thresholds were tuned for.
 *
 * @param obb Input oriented bounding box (reads box.width, box.height, angle).
 * @return std::array<float, 3> Covariance matrix components {a, b, c}.
 */
std::array<float, 3> get_covariance_matrix(YOLOBbox& obb) {
    const float half_width = obb.box.width / 2.0f;
    const float half_height = obb.box.height / 2.0f;
    const float angle = obb.angle;

    const float cos_angle = std::cos(angle);
    const float sin_angle = std::sin(angle);
    const float cos_sq = cos_angle * cos_angle;
    const float sin_sq = sin_angle * sin_angle;

    const float var_w = half_width * half_width;
    const float var_h = half_height * half_height;

    const float a = var_w * cos_sq + var_h * sin_sq;
    const float b = var_w * sin_sq + var_h * cos_sq;
    // Off-diagonal term of the rotated matrix: it must vanish when the two
    // variances are equal, since a rotated isotropic Gaussian has no correlation.
    const float c = (var_w - var_h) * cos_angle * sin_angle;

    return {a, b, c};
}

/**
 * @brief Calculate a probabilistic IoU (ProbIoU) for rotated bounding boxes.
 *
 * Each box is treated as a 2-D Gaussian (centre = box.x/box.y, covariance from
 * get_covariance_matrix()). The Bhattacharyya distance `bd` between the two
 * Gaussians is computed as t1 + t2 + t3, clamped to [eps, 100], converted to
 * the Hellinger distance `hd = sqrt(1 - exp(-bd) + eps)`, and the similarity
 * returned is `1 - hd`.
 *
 * @param obb1 First oriented bounding box.
 * @param obb2 Second oriented bounding box.
 * @param eps Small constant to prevent division by zero and log(0).
 * @return float Rotated IoU value.
 */
float iou_rotate_calculate(YOLOBbox& obb1,YOLOBbox& obb2,float eps) {
    const float x1 = obb1.box.x, y1 = obb1.box.y;
    const float x2 = obb2.box.x, y2 = obb2.box.y;

    const auto [a1, b1, c1] = get_covariance_matrix(obb1);
    const auto [a2, b2, c2] = get_covariance_matrix(obb2);

    const float sum_a = a1 + a2;
    const float sum_b = b1 + b2;
    const float sum_c = c1 + c2;
    // Determinant of the summed covariance matrix.
    const float det_sum = sum_a * sum_b - sum_c * sum_c;
    const float denom = det_sum + eps;

    const float dx = x1 - x2;
    const float dy = y1 - y2;

    const float t1 = (sum_a * dy * dy + sum_b * dx * dx) / denom * 0.25f;
    const float t2 = (sum_c * (x2 - x1) * (y1 - y2)) / denom * 0.5f;

    // Rounding can push a near-singular determinant slightly below zero, which
    // would turn the square root into NaN; clamp each one at zero first.
    const float det1 = std::max(a1 * b1 - c1 * c1, 0.0f);
    const float det2 = std::max(a2 * b2 - c2 * c2, 0.0f);
    const float denom_log = 4.0f * std::sqrt(det1 * det2) + eps;

    const float t3 = 0.5f * std::log(det_sum / denom_log + eps);

    const float bd = def_clamp(t1 + t2 + t3, eps, 100.0f);
    const float hd = std::sqrt(1.0f - std::exp(-bd) + eps);

    return 1.0f - hd;
}

/**
 * @brief Compute the four corners of a rotated rectangle.
 *
 * The rectangle is described by its centre, full width/height and rotation
 * angle. With w = half-width vector and h = half-height vector (both rotated
 * by @p angle), the corners are returned in the order
 * centre + w + h, centre - w + h, centre - w - h, centre + w - h.
 * Coordinates are converted to int by truncation.
 *
 * @param x_center Center x-coordinate.
 * @param y_center Center y-coordinate.
 * @param width Box width.
 * @param height Box height.
 * @param angle Rotation angle in radians.
 * @return std::vector<std::pair<int, int>> List of (x, y) corner coordinates.
 */
std::vector<std::pair<int, int>> calculate_obb_corners(float x_center, float y_center, float width, float height, float angle) {
    const float c = std::cos(angle);
    const float s = std::sin(angle);
    const float half_w = width / 2.0f;
    const float half_h = height / 2.0f;

    // Components of the rotated half-width (w*) and half-height (h*) vectors.
    const float wx = c * half_w;
    const float wy = s * half_w;
    const float hx = s * half_h;
    const float hy = c * half_h;

    std::vector<std::pair<int, int>> corners;
    corners.reserve(4);
    corners.emplace_back(static_cast<int>(x_center + wx - hx),
                         static_cast<int>(y_center + wy + hy));
    corners.emplace_back(static_cast<int>(x_center - wx - hx),
                         static_cast<int>(y_center - wy + hy));
    corners.emplace_back(static_cast<int>(x_center - wx + hx),
                         static_cast<int>(y_center - wy - hy));
    corners.emplace_back(static_cast<int>(x_center + wx + hx),
                         static_cast<int>(y_center + wy - hy));
    return corners;
}

/**
 * @brief Perform Non-Maximum Suppression (NMS) for rotated bounding boxes (OBB).
 * 
 * @param bboxes Input and output oriented bounding box list.
 * @param confThreshold Confidence threshold for filtering.
 * @param nmsThreshold IoU threshold for suppression.
 * @param indices Output indices of retained boxes.
 */
void rotate_nms(std::vector<YOLOBbox> &bboxes, float confThreshold, float nmsThreshold,std::vector<int> &indices)
{
    std::sort(bboxes.begin(), bboxes.end(), [](const YOLOBbox &a, const YOLOBbox &b) { return a.confidence > b.confidence; });

    int updated_size = bboxes.size();
    for (int i = 0; i < updated_size; i++) {
        if (bboxes[i].confidence < confThreshold)
            continue;
        indices.push_back(i);
        for (int j = i + 1; j < updated_size;) {
            float iou = iou_rotate_calculate(bboxes[i], bboxes[j]);
            if (iou > nmsThreshold) {
                bboxes[j].confidence = -1;  
            }
            j++;
        }
    }
    bboxes.erase(std::remove_if(bboxes.begin(), bboxes.end(), [](const YOLOBbox &b) { return b.confidence < 0; }), bboxes.end());
}
