#include "yolo11.h"

#include <algorithm>
#include <cmath>

/**
 * @brief Yolo11 Constructor
 *
 * Loads the KModel onto the KPU, builds the AI2D preprocessing pipeline that maps an
 * original frame of size @p image_wh to the model input size, and configures the
 * task-specific parameters for "classify", "detect", "segment" or "obb".
 *
 * Preprocessing per task:
 *  - classify: center-crop the largest square the frame allows, then resize to the
 *    model input.
 *  - detect/segment/obb: aspect-preserving resize with constant (128,128,128) padding
 *    added on the right/bottom only, so post_process can undo it with a single ratio.
 *
 * An unsupported task type prints an error and terminates the process.
 *
 * @param task_type      Task type: "classify", "detect", "segment" or "obb"
 * @param task_mode      Task mode: "video" or "image" (stored in task_mode_)
 * @param kmodel_file    Path to the KModel file
 * @param image_wh       Original image size (width, height)
 * @param labels         Class label names
 * @param conf_thres     Confidence threshold for detections
 * @param nms_thres      Non-Maximum Suppression (NMS) threshold
 * @param mask_thres     Mask threshold for segmentation
 * @param debug_mode     Debug level: 0 = off, 1 = time only, 2 = verbose
 */
Yolo11::Yolo11(std::string task_type, std::string task_mode, std::string kmodel_file,
               FrameSize image_wh, std::vector<std::string> labels,
               float conf_thres, float nms_thres, float mask_thres, int debug_mode)
{
    task_type_ = task_type;
    task_mode_ = task_mode;
    conf_thres_ = conf_thres;
    nms_thres_ = nms_thres;
    mask_thres_ = mask_thres;
    image_wh_ = image_wh;

    // Only the detect/segment/obb paths allocate this buffer. Start from nullptr so the
    // destructor's delete[] is also safe when the task is "classify".
    output_det = nullptr;

    // Initialize KPU model and AI2D preprocessor
    kpu = new KPU(kmodel_file);
    ai2d_yolo = new AI2D();

    // Model input is NCHW: dim 3 is width, dim 2 is height.
    dims_t shape_i = kpu->get_input_shape(0);
    input_wh_ = {(int)shape_i[3], (int)shape_i[2]};

    labels_ = labels;
    label_num_ = labels_.size();
    colors = get_colors_for_classes(label_num_);

    max_box_num_ = 50;

    // Number of candidate boxes produced by the three detection strides (8, 16, 32).
    box_num_ = ((input_wh_.width / 8) * (input_wh_.height / 8) +
                (input_wh_.width / 16) * (input_wh_.height / 16) +
                (input_wh_.width / 32) * (input_wh_.height / 32));

    debug_mode_ = debug_mode;
    model_input_tensor_ = kpu->get_input_tensor(0);

    // Print model input and output shapes if debug mode is enabled.
    // Dims are unsigned and possibly wider than int, so print them via unsigned long.
    if (debug_mode_ > 0)
    {
        printf("-----------------shape info-----------------\n");
        int input_size = kpu->get_input_size();
        for (int i = 0; i < input_size; i++)
        {
            dims_t shape = kpu->get_input_shape(i);
            printf("input %d shape: [ ", i);
            for (size_t j = 0; j < shape.size(); j++)
            {
                printf("%lu ", (unsigned long)shape[j]);
            }
            printf("]\n");
        }
        int output_size = kpu->get_output_size();
        for (int i = 0; i < output_size; i++)
        {
            dims_t shape = kpu->get_output_shape(i);
            printf("output %d shape: [ ", i);
            for (size_t j = 0; j < shape.size(); j++)
            {
                printf("%lu ", (unsigned long)shape[j]);
            }
            printf("]\n");
        }
        printf("--------------shape info end-----------------\n");
    }

    // Configure preprocessing based on task type
    if (task_type_ == "classify")
    {
        // Center-crop the largest square the source frame allows (its short side), then
        // let AI2D resize that square to the model input. The side length must come from
        // the frame, not from the model input size.
        int min_m = std::min(image_wh.width, image_wh.height);
        int top = (image_wh.height - min_m) / 2;
        int left = (image_wh.width - min_m) / 2;

        ai2d_yolo->set_ai2d_dtype(ai2d_format::NCHW_FMT, ai2d_format::NCHW_FMT,
                                  typecode_t::dt_uint8, typecode_t::dt_uint8);
        ai2d_yolo->set_crop(left, top, min_m, min_m);
        ai2d_yolo->set_resize(ai2d_interp_method::tf_bilinear, ai2d_interp_mode::half_pixel);

        dims_t input_shape{1, 3, (long unsigned int)image_wh.height, (long unsigned int)image_wh.width};
        dims_t output_shape{1, 3, (long unsigned int)input_wh_.height, (long unsigned int)input_wh_.width};
        ai2d_yolo->build(input_shape, output_shape);
    }
    else if (task_type_ == "detect" || task_type_ == "segment" || task_type_ == "obb")
    {
        // Per-box feature layout: 4 box coords + class scores [+ 32 mask coeffs | + 1 angle].
        if (task_type_ == "detect")
        {
            box_feature_len_ = label_num_ + 4;
        }
        else if (task_type_ == "segment")
        {
            box_feature_len_ = label_num_ + 4 + 32;
        }
        else
        {
            box_feature_len_ = label_num_ + 5;
        }

        // Scratch buffer for the transposed detection output
        output_det = new float[box_num_ * box_feature_len_];

        float ratiow = (float)input_wh_.width / image_wh_.width;
        float ratioh = (float)input_wh_.height / image_wh_.height;
        float ratio = std::min(ratiow, ratioh);

        int new_w = static_cast<int>(ratio * image_wh_.width);
        int new_h = static_cast<int>(ratio * image_wh_.height);

        // Pad only right/bottom so the image stays anchored at the top-left corner.
        int top = 0, bottom = input_wh_.height - new_h;
        int left = 0, right = input_wh_.width - new_w;

        ai2d_yolo->set_ai2d_dtype(ai2d_format::NCHW_FMT, ai2d_format::NCHW_FMT,
                                  typecode_t::dt_uint8, typecode_t::dt_uint8);

        std::vector<int> pad_param = {0, 0, 0, 0, top, bottom, left, right};
        ai2d_yolo->set_pad(pad_param, ai2d_pad_mode::constant, {128, 128, 128});
        ai2d_yolo->set_resize(ai2d_interp_method::tf_bilinear, ai2d_interp_mode::half_pixel);

        dims_t input_shape{1, 3, (long unsigned int)image_wh.height, (long unsigned int)image_wh.width};
        dims_t output_shape{1, 3, (long unsigned int)input_wh_.height, (long unsigned int)input_wh_.width};
        ai2d_yolo->build(input_shape, output_shape);
    }
    else
    {
        std::cerr << "Unsupported task type: " << task_type_ << std::endl;
        exit(EXIT_FAILURE);
    }
}

/**
 * @brief Destructor
 *
 * Frees the AI2D preprocessor, the KPU wrapper and the detection scratch buffer,
 * leaving each member pointer null afterwards.
 */
Yolo11::~Yolo11()
{
    // delete / delete[] on a null pointer is a no-op, so no null checks are required.
    delete ai2d_yolo;
    ai2d_yolo = nullptr;

    delete kpu;
    kpu = nullptr;

    delete[] output_det;
    output_det = nullptr;
}

/**
 * @brief Preprocess input image tensor
 *
 * Runs the AI2D pipeline built in the constructor on @p input_tensor and writes the
 * result directly into the KPU's input tensor (model_input_tensor_).
 *
 * @param input_tensor Input tensor containing image data
 */
void Yolo11::pre_process(runtime_tensor &input_tensor)
{
    PROFILE_SCOPE("Yolo11::pre_process", debug_mode_);
    AI2D &preprocessor = *ai2d_yolo;
    preprocessor.run(input_tensor, this->model_input_tensor_);
}

/**
 * @brief Run model inference
 *
 * Executes the YOLO11 KPU inference on the preprocessed input tensor.
 */
void Yolo11::inference()
{
    PROFILE_SCOPE("Yolo11::inference", debug_mode_);
    kpu->run();
}

/**
 * @brief Postprocess model outputs
 *
 * Decodes raw KPU outputs, applies confidence thresholding and NMS,
 * and generates bounding boxes for classification, detection, or segmentation.
 *
 * @param yolo_results Output vector to store detection results
 */
void Yolo11::post_process(std::vector<YOLOBbox> &yolo_results)
{
    PROFILE_SCOPE("Yolo11::post_process", debug_mode_);
    yolo_results.clear();

    if (task_type_ == "classify")
    {
        // Handle classification task output
        float *output0 = (float *)kpu->get_output_ptr(0);
        YOLOBbox res;

        if (label_num_ > 2)
        {
            // Softmax for multi-class classification
            float sum = 0.0f;
            for (int i = 0; i < label_num_; i++)
                sum += exp(output0[i]);

            for (int i = 0; i < label_num_; i++)
                output0[i] = exp(output0[i]) / sum;

            int max_index = std::max_element(output0, output0 + label_num_) - output0;

            if (output0[max_index] >= conf_thres_)
            {
                res.index = max_index;
                res.confidence = output0[max_index];
                yolo_results.push_back(res);
            }
        }
        else
        {
            // Binary classification
            float pred = sigmoid(output0[0]);
            if (pred > conf_thres_)
            {
                res.index = 0;
                res.confidence = pred;
            }
            else
            {
                res.index = 1;
                res.confidence = 1 - pred;
            }
            yolo_results.push_back(res);
        }
    }
    else if (task_type_ == "detect")
    {
        // Handle detection task output
        float ratiow = (float)input_wh_.width / image_wh_.width;
        float ratioh = (float)input_wh_.height / image_wh_.height;
        float ratio = std::min(ratiow, ratioh);

        memset(output_det, 0, sizeof(float) * box_num_ * box_feature_len_);
        float* output0= (float *)kpu->get_output_ptr(0);
        // [label_num_+4,(w/8)*(h/8)+(w/16)*(h/16)+(w/32)*(h/32)] to [(w/8)*(h/8)+(w/16)*(h/16)+(w/32)*(h/32),label_num_+4]
        for(int r = 0; r < box_num_; r++)
        {
            for(int c = 0; c < box_feature_len_; c++)
            {
                output_det[r*box_feature_len_ + c] = output0[c*box_num_ + r];
            }
        }

        for (int i = 0; i < box_num_; i++)
        {
            float *vec = output_det + i * box_feature_len_;
            float box[4] = {vec[0], vec[1], vec[2], vec[3]};
            float *class_scores = vec + 4;
            float *max_class_score_ptr = std::max_element(class_scores, class_scores + label_num_);
            float score = (*max_class_score_ptr);
            int max_class_index = max_class_score_ptr - class_scores;

            if (score > conf_thres_)
            {
                YOLOBbox bbox;
                float x_ = box[0] / ratio;
                float y_ = box[1] / ratio;
                float w_ = box[2] / ratio;
                float h_ = box[3] / ratio;

                int x = std::max(int(x_ - 0.5f * w_), 0);
                int y = std::max(int(y_ - 0.5f * h_), 0);
                int w = int(w_);
                int h = int(h_);
                if (w <= 0 || h <= 0)
                    continue;

                bbox.box = cv::Rect(x, y, w, h);
                bbox.confidence = score;
                bbox.index = max_class_index;
                yolo_results.push_back(bbox);
            }
        }

        // Perform Non-Maximum Suppression (NMS)
        std::vector<int> nms_result;
        nms(yolo_results, conf_thres_, nms_thres_, nms_result);
    }
    else if (task_type_ == "segment")
    {
        // Handle segmentation task output
        float ratiow = input_wh_.width / (float)image_wh_.width;
        float ratioh = input_wh_.height / (float)image_wh_.height;
        float ratio = std::min(ratiow, ratioh);

        int new_w = int(image_wh_.width * ratio);
        int new_h = int(image_wh_.height * ratio);
        int pad_w = std::max(input_wh_.width - new_w, 0);
        int pad_h = std::max(input_wh_.height - new_h, 0);

        memset(output_det, 0, sizeof(float) * box_num_ * box_feature_len_);
        float* output0= (float *)kpu->get_output_ptr(0);
        // [label_num_+4+32,(w/8)*(h/8)+(w/16)*(h/16)+(w/32)*(h/32)] to [(w/8)*(h/8)+(w/16)*(h/16)+(w/32)*(h/32),label_num_+4+32]
        for(int r = 0; r < box_num_; r++)
        {
            for(int c = 0; c < box_feature_len_; c++)
            {
                output_det[r*box_feature_len_ + c] = output0[c*box_num_ + r];
            }
        }

        float* output1=(float *)kpu->get_output_ptr(1);

        int mask_w = input_wh_.width / 4;
        int mask_h = input_wh_.height / 4;

        cv::Mat protos(32, mask_w * mask_h, CV_32FC1, output1);

        for (int i = 0; i < box_num_; i++)
        {
            float *vec = output_det + i * box_feature_len_;
            float box[4] = {vec[0], vec[1], vec[2], vec[3]};
            float *class_scores = vec + 4;
            float *max_class_score_ptr = std::max_element(class_scores, class_scores + label_num_);
            float score = (*max_class_score_ptr);
            int max_class_index = max_class_score_ptr - class_scores;

            if (score > conf_thres_)
            {
                YOLOBbox bbox;
                float x_ = box[0] / ratio;
                float y_ = box[1] / ratio;
                float w_ = box[2] / ratio;
                float h_ = box[3] / ratio;

                int x = std::max(int(x_ - 0.5f * w_), 0);
                int y = std::max(int(y_ - 0.5f * h_), 0);
                int w = int(w_);
                int h = int(h_);
                if (w <= 0 || h <= 0)
                    continue;

                bbox.box = cv::Rect(x, y, w, h);
                bbox.confidence = score;
                bbox.index = max_class_index;

                bbox.mask=cv::Mat(1, 32, CV_32F, vec + label_num_ + 4);
                yolo_results.push_back(bbox);
            }
        }

        // Perform Non-Maximum Suppression (NMS)
        std::vector<int> nms_result;
        nms(yolo_results, conf_thres_, nms_thres_, nms_result);
        for (int i = 0; i < yolo_results.size(); i++)
        {
            cv::Mat mask_box = yolo_results[i].mask * protos;
            cv::Mat mask_box_(mask_h, mask_w, CV_32FC1, mask_box.data);
            cv::Rect roi(0, 0, mask_w - int(pad_w * (mask_w / float(input_wh_.width))),
                        mask_h - int(pad_h * (mask_h / float(input_wh_.height))));

            cv::Mat dest;
            cv::exp(-mask_box_, dest);
            dest = 1.0 / (1.0 + dest);
            dest = dest(roi);
            yolo_results[i].mask = dest;
        }
    }
    else if (task_type_ == "obb"){
        // Handle obb task output
        float ratiow = input_wh_.width / (float)image_wh_.width;
        float ratioh = input_wh_.height / (float)image_wh_.height;
        float ratio = std::min(ratiow, ratioh);

        memset(output_det, 0, sizeof(float) * box_num_ * box_feature_len_);
        float* output0= (float *)kpu->get_output_ptr(0);
        // [label_num_+5,(w/8)*(h/8)+(w/16)*(h/16)+(w/32)*(h/32)] to [(w/8)*(h/8)+(w/16)*(h/16)+(w/32)*(h/32),label_num_+5]
        for(int r = 0; r < box_num_; r++)
        {
            for(int c = 0; c < box_feature_len_; c++)
            {
                output_det[r*box_feature_len_ + c] = output0[c*box_num_ + r];
            }
        }

        for (int i = 0; i < box_num_; i++)
        {
            float *vec = output_det + i * box_feature_len_;
            float box[4] = {vec[0], vec[1], vec[2], vec[3]};
            float *class_scores = vec + 4;
            float *max_class_score_ptr = std::max_element(class_scores, class_scores + label_num_);
            float score = (*max_class_score_ptr);
            int max_class_index = max_class_score_ptr - class_scores;
            float angle=vec[4+label_num_];

            if (score > conf_thres_)
            {
                YOLOBbox bbox;
                float x_ = box[0] / ratio;
                float y_ = box[1] / ratio;
                float w_ = box[2] / ratio;
                float h_ = box[3] / ratio;

                int x = std::max(int(x_), 0);
                int y = std::max(int(y_), 0);
                int w = int(w_);
                int h = int(h_);
                if (w <= 0 || h <= 0)
                    continue;

                bbox.box = cv::Rect(x, y, w, h);
                bbox.confidence = score;
                bbox.angle=angle;
                bbox.index = max_class_index;
                yolo_results.push_back(bbox);
            }
        }

        // Perform Non-Maximum Suppression (NMS)
        std::vector<int> nms_result;
        rotate_nms(yolo_results, conf_thres_, nms_thres_, nms_result);
    }
    else
    {
        std::cerr << "Unsupported task type: " << task_type_ << std::endl;
        exit(EXIT_FAILURE);
    }
}

/**
 * @brief Draw detection results on the frame
 *
 * Results from post_process are in original-image coordinates (image_wh_); they are
 * rescaled to @p draw_frame's size before drawing. At most max_box_num_ results are
 * drawn.
 *  - classify: label and score text for the first result.
 *  - detect/segment: rectangle plus "label score" text; for segment, the box's mask
 *    is upsampled to the frame and pixels above mask_thres_ are filled with the
 *    class color inside the box.
 *  - obb: the four edges of the oriented box plus the class index.
 *
 * @param draw_frame    Frame to draw on
 * @param yolo_results  Vector of detection results
 */
void Yolo11::draw_results(cv::Mat &draw_frame, std::vector<YOLOBbox> &yolo_results)
{
    PROFILE_SCOPE("Yolo11::draw_results", debug_mode_);

    int w_ = draw_frame.cols;
    int h_ = draw_frame.rows;
    int res_size = std::min((int)yolo_results.size(), max_box_num_);

    if (task_type_ == "classify")
    {
        if (res_size > 0)
        {
            const YOLOBbox &best = yolo_results[0];
            std::string text = labels_[best.index] +
                               " score:" + std::to_string(best.confidence);
            cv::putText(draw_frame, text, cv::Point(50, 50), cv::FONT_HERSHEY_DUPLEX, 1,
                        colors[best.index], 2, 0);
        }
    }
    else if (task_type_ == "detect" || task_type_ == "segment")
    {
        for (int i = 0; i < res_size; i++)
        {
            // Borrow instead of copying (the copy would also bump the mask's refcount).
            const YOLOBbox &det = yolo_results[i];
            cv::Rect box = det.box;
            int idx = det.index;
            float score = det.confidence;

            int x = int(box.x * float(w_) / image_wh_.width);
            int y = int(box.y * float(h_) / image_wh_.height);
            int w = int(box.width * float(w_) / image_wh_.width);
            int h = int(box.height * float(h_) / image_wh_.height);

            w = std::min(w, w_ - x);
            h = std::min(h, h_ - y);

            cv::Rect new_box(x, y, w, h);
            cv::rectangle(draw_frame, new_box, colors[idx], 2, 8);
            cv::putText(draw_frame, labels_[idx] + " " + std::to_string(score),
                        cv::Point(std::min(new_box.x + 5, w_), std::max(new_box.y - 10, 0)),
                        cv::FONT_HERSHEY_DUPLEX, 1, colors[idx], 2, 0);

            // Draw the mask for segmentation. A box starting at/after the frame edge has
            // non-positive size after clipping; using it as an ROI would make OpenCV throw.
            if (task_type_ == "segment" && w > 0 && h > 0)
            {
                cv::Mat mask_d;
                // Interpolation is the 6th parameter; the 4th/5th are the fx/fy scale factors.
                cv::resize(det.mask, mask_d, cv::Size(w_, h_), 0, 0, cv::INTER_NEAREST);
                mask_d = mask_d(new_box) > mask_thres_;
                draw_frame(new_box).setTo(colors[idx], mask_d);
            }
        }
    }
    else if (task_type_ == "obb")
    {
        for (int i = 0; i < res_size; i++)
        {
            const YOLOBbox &det = yolo_results[i];
            cv::Rect box = det.box;
            int idx = det.index;

            std::vector<std::pair<int, int>> corners =
                calculate_obb_corners(box.x, box.y, box.width, box.height, det.angle);

            // Rescale the four corners from original-image to frame coordinates.
            cv::Point pts[4];
            for (int k = 0; k < 4; k++)
            {
                pts[k] = cv::Point(int(corners[k].first * float(w_) / image_wh_.width),
                                   int(corners[k].second * float(h_) / image_wh_.height));
            }
            for (int k = 0; k < 4; k++)
            {
                cv::line(draw_frame, pts[k], pts[(k + 1) % 4], colors[idx], 2);
            }
            cv::putText(draw_frame, std::to_string(idx), cv::Point(pts[0].x, pts[0].y - 10),
                        cv::FONT_HERSHEY_DUPLEX, 1, colors[idx], 2, 0);
        }
    }
    else
    {
        std::cerr << "Unsupported task type: " << task_type_ << std::endl;
        exit(EXIT_FAILURE);
    }
}

/**
 * @brief Run complete detection pipeline
 *
 * Convenience wrapper: pre_process() -> inference() -> post_process(), in that order.
 *
 * @param input_tensor  Input image tensor.
 * @param yolo_results  Output vector for results (cleared by post_process).
 */
void Yolo11::run(runtime_tensor &input_tensor, std::vector<YOLOBbox> &yolo_results)
{
    pre_process(input_tensor);
    inference();
    post_process(yolo_results);
}