/**
 * @file face_detection.cpp
 * @brief Implementation of the FaceDetection class for running face detection inference using KPU and AI2D.
 * 
 * This module handles the complete face detection pipeline, including preprocessing,
 * model inference, post-processing, and result visualization (with OpenCV drawing utilities).
 * It supports different model input sizes (320 and 640) and uses pre-defined anchor grids.
 */

#include "face_detection.h"
#include "k230_math.h"

// External anchor tables, defined in another translation unit. Each row holds four
// floats: get_box()/get_landmark() use columns 0 and 1 as the x/y base of the decoded
// position and columns 2 and 3 as the x/y scale. The "320" table has 4200 rows and the
// "640" table has 16800 rows.
extern float kAnchors320[4200][4];
extern float kAnchors640[16800][4];

// Anchor table currently in use; assigned in the FaceDetection constructor.
// NOTE(review): this is a file-scope static, so it is shared by every FaceDetection
// instance created from this translation unit. Constructing a second instance with a
// different model resolution re-points it for all instances.
static float (*g_anchors)[4];

/**
 * @brief Per-landmark colors used by draw_results() on 3-channel frames (BGR).
 *
 * Entry `ll` is the color of landmark `ll`; draw_results() indexes this table
 * with ll in [0, 5), so it must keep at least five entries.
 */
cv::Scalar color_list_for_det[] = {
    cv::Scalar(0, 0, 255),
    cv::Scalar(0, 255, 255),
    cv::Scalar(255, 0, 255),
    cv::Scalar(0, 255, 0),
    cv::Scalar(255, 0, 0)
};

/**
 * @brief Per-landmark colors used by draw_results() on non-3-channel (OSD) frames.
 *
 * Entry `ll` is the color of landmark `ll`; draw_results() indexes this table
 * with ll in [0, 5). Every entry carries 255 in its fourth component.
 *
 * NOTE(review): these values are labelled BGRA. Under that reading the first entry
 * is blue, whereas the first entry of the 3-channel table is red under BGR — confirm
 * the channel order against the actual OSD pixel format.
 */
cv::Scalar color_list_for_osd_det[] = {
    cv::Scalar(255, 0, 0, 255),
    cv::Scalar(255, 255, 0, 255),
    cv::Scalar(0, 255, 255, 255),
    cv::Scalar(0, 255, 0, 255),
    cv::Scalar(0, 0, 255, 255)
};

/**
 * @brief Comparator used by qsort() to sort detected objects by descending confidence.
 * 
 * @param pa Pointer to first object
 * @param pb Pointer to second object
 * @return int Comparison result (used by qsort)
 */
int nms_comparator(const void *pa, const void *pb)
{
    NMSRoiObj a = *(NMSRoiObj *)pa;
    NMSRoiObj b = *(NMSRoiObj *)pb;
    float diff = a.confidence - b.confidence;

    if (diff < 0)
        return 1;
    else if (diff > 0)
        return -1;
    return 0;
}

// ---------------- public methods --------------------

/**
 * @brief Construct a FaceDetection pipeline.
 *
 * Loads the KPU model, selects the anchor table matching the model resolution,
 * allocates the post-processing buffers and configures/builds the AI2D
 * pad + resize pipeline that maps a source frame onto the model input.
 *
 * @param kmodel_file  Path to the compiled KPU model (.kmodel)
 * @param obj_thresh   Object confidence threshold
 * @param nms_thresh   NMS (non-maximum suppression) IoU threshold
 * @param image_size   Source frame size; its width and height drive the resize ratio
 * @param debug_mode   Debug mode flag (forwarded to the profiling macro)
 */
FaceDetection::FaceDetection(std::string &kmodel_file, float obj_thresh, float nms_thresh,
                             FrameCHWSize image_size, int debug_mode)
{
    obj_thresh_ = obj_thresh;
    nms_thresh_ = nms_thresh;
    image_size_ = image_size;
    debug_mode_ = debug_mode;

    // Model runner and preprocessor
    kpu = new KPU(kmodel_file);
    ai2d = new AI2D();

    // Model input shape, indexed as [N, C, H, W]
    dims_t model_shape = kpu->get_input_shape(0);
    // NOTE(review): relies on the FrameCHWSize member order matching
    // {shape[2], shape[3], shape[1]} — confirm against the struct declaration.
    input_size_ = {model_shape[2], model_shape[3], model_shape[1]};

    // A 320 model uses the small anchor table; anything else is treated as the 640 model
    if (model_shape[2] == 320)
    {
        min_size_ = 200;
        g_anchors = kAnchors320;
    }
    else
    {
        min_size_ = 800;
        g_anchors = kAnchors640;
    }

    // One slot per anchor: the three output branches contribute 16x, 4x and 1x min_size_
    objs_num_ = min_size_ * (16 + 4 + 1);
    so_ = new NMSRoiObj[objs_num_];
    boxes_ = new float[objs_num_ * LOC_SIZE];
    landmarks_ = new float[objs_num_ * LAND_SIZE];

    // The KPU input tensor is fetched once and reused by pre_process()
    model_input_tensor_ = kpu->get_input_tensor(0);

    // Aspect-preserving scale; the leftover area is padded at the bottom and right only
    const float scale_w = static_cast<float>(input_size_.width) / image_size_.width;
    const float scale_h = static_cast<float>(input_size_.height) / image_size_.height;
    const float scale = std::min(scale_w, scale_h);

    const int scaled_w = static_cast<int>(scale * image_size_.width);
    const int scaled_h = static_cast<int>(scale * image_size_.height);
    const int pad_top = 0;
    const int pad_bottom = input_size_.height - scaled_h;
    const int pad_left = 0;
    const int pad_right = input_size_.width - scaled_w;

    // uint8 NCHW in, uint8 NCHW out
    ai2d->set_ai2d_dtype(ai2d_format::NCHW_FMT, ai2d_format::NCHW_FMT,
                         typecode_t::dt_uint8, typecode_t::dt_uint8);

    std::vector<int> pad_param = {0, 0, 0, 0, pad_top, pad_bottom, pad_left, pad_right};
    ai2d->set_pad(pad_param, ai2d_pad_mode::constant, {128, 128, 128});
    ai2d->set_resize(ai2d_interp_method::tf_bilinear, ai2d_interp_mode::half_pixel);

    // Source frame shape -> model input shape
    dims_t input_shape{1, 3,
                       static_cast<long unsigned int>(image_size_.height),
                       static_cast<long unsigned int>(image_size_.width)};
    dims_t output_shape{1, 3,
                        static_cast<long unsigned int>(input_size_.height),
                        static_cast<long unsigned int>(input_size_.width)};
    ai2d->build(input_shape, output_shape);
}

/**
 * @brief Destructor for FaceDetection.
 *
 * Releases the AI2D and KPU objects, then the three post-processing buffers.
 */
FaceDetection::~FaceDetection()
{
    // delete on a null pointer is a no-op, so explicit null checks are unnecessary.
    delete ai2d;
    ai2d = nullptr;

    delete kpu;
    kpu = nullptr;

    delete[] so_;
    delete[] boxes_;
    delete[] landmarks_;
}

/**
 * @brief Preprocessing stage: runs the AI2D pipeline to fill the model input tensor.
 *
 * Feeds @p input_tensor through the AI2D instance configured in the constructor
 * (constant padding + bilinear resize) and writes the result into
 * model_input_tensor_, the tensor obtained from kpu->get_input_tensor(0).
 *
 * @param input_tensor Input frame tensor; the AI2D pipeline was built for shape
 *                     {1, 3, image_size_.height, image_size_.width}.
 */
void FaceDetection::pre_process(runtime_tensor &input_tensor)
{
    // Profiling hook keyed on debug_mode_ (macro defined outside this file).
    PROFILE_SCOPE_AUTO(debug_mode_);
    ai2d->run(input_tensor, model_input_tensor_);
}

/**
 * @brief Performs model inference on KPU.
 *
 * Runs the loaded model via kpu->run(). The outputs are read afterwards by
 * post_process() through kpu->get_output_ptr().
 */
void FaceDetection::inference()
{
    // Profiling hook keyed on debug_mode_ (macro defined outside this file).
    PROFILE_SCOPE_AUTO(debug_mode_);
    kpu->run();
}

/**
 * @brief Postprocessing stage: turns the nine model outputs into face detections.
 *
 * Outputs 0-2 are unpacked as box regressions, 3-5 as class scores and 6-8 as
 * landmark regressions, one output per branch. The scored entries are then sorted by
 * descending confidence and handed to get_final_box(), which thresholds, applies NMS
 * and maps results to the source image size.
 *
 * @param results Vector that receives the final detections (entries are appended)
 */
void FaceDetection::post_process(vector<FaceDetectionInfo> &results)
{
    PROFILE_SCOPE_AUTO(debug_mode_);

    // Fetch all output buffers up front, in index order
    float *outputs[9];
    for (int idx = 0; idx < 9; ++idx)
        outputs[idx] = (float *)kpu->get_output_ptr(idx);

    // Positions per branch; each position yields two entries in the deal_* helpers
    const int branch_sizes[3] = {16 * min_size_ / 2, 4 * min_size_ / 2, 1 * min_size_ / 2};

    // Each pass restarts the running entry index so the three buffers stay aligned
    int cursor = 0;
    for (int branch = 0; branch < 3; ++branch)
        deal_conf(outputs[3 + branch], so_, branch_sizes[branch], cursor);

    cursor = 0;
    for (int branch = 0; branch < 3; ++branch)
        deal_loc(outputs[branch], boxes_, branch_sizes[branch], cursor);

    cursor = 0;
    for (int branch = 0; branch < 3; ++branch)
        deal_landms(outputs[6 + branch], landmarks_, branch_sizes[branch], cursor);

    // Highest confidence first
    qsort(so_, objs_num_, sizeof(NMSRoiObj), nms_comparator);

    // Threshold, NMS and rescale to the source frame
    get_final_box(image_size_, results);
}

/**
 * @brief Draw detected faces and landmarks on the given image frame.
 *
 * Two drawing modes are selected by the channel count of @p draw_frame:
 *  - 3 channels: coordinates are used as-is (the frame is assumed to match the
 *    source image size); box, score text and landmarks are drawn.
 *  - otherwise (OSD): coordinates are rescaled from image_size_ to the frame size;
 *    box and landmarks are drawn, no score text.
 *
 * @param draw_frame Image to draw results on
 * @param results    Detected faces, with coordinates in source-image pixels
 */
void FaceDetection::draw_results(cv::Mat& draw_frame, vector<FaceDetectionInfo>& results)
{
    PROFILE_SCOPE_AUTO(debug_mode_);
    const int draw_frame_w = draw_frame.cols;
    const int draw_frame_h = draw_frame.rows;
    const bool is_three_channel = (draw_frame.channels() == 3);

    for (const auto& face : results)
    {
        const auto& l = face.sparse_kps;
        for (uint32_t ll = 0; ll < 5; ll++)
        {
            if (is_three_channel)
            {
                // Landmarks on a 3-channel frame: no rescaling
                int32_t x0 = l.points[2 * ll + 0];
                int32_t y0 = l.points[2 * ll + 1];
                cv::circle(draw_frame, cv::Point(x0, y0), 2, color_list_for_det[ll], 4);
            }
            else
            {
                // Landmarks on an OSD frame: rescale from source size to frame size
                int32_t x0 = l.points[2 * ll] / image_size_.width * draw_frame_w;
                int32_t y0 = l.points[2 * ll + 1] / image_size_.height * draw_frame_h;
                cv::circle(draw_frame, cv::Point(x0, y0), 4, color_list_for_osd_det[ll], 8);
            }
        }

        const auto& b = face.bbox;
        if (is_three_channel)
        {
            // Bounded formatting: a non-finite or unexpectedly large score can no
            // longer overrun the buffer (it is truncated instead).
            char text[16];
            snprintf(text, sizeof(text), "%.2f", face.score);

            cv::rectangle(draw_frame, cv::Rect(b.x, b.y, b.w, b.h), cv::Scalar(255, 255, 255), 2, 2, 0);
            cv::putText(draw_frame, text, cv::Point(static_cast<int>(b.x), static_cast<int>(b.y)),
                        cv::FONT_HERSHEY_COMPLEX, 0.5, cv::Scalar(0, 255, 255), 1, 8, 0);
        }
        else
        {
            // Bounding box on an OSD frame: rescale from source size to frame size
            int x = b.x / image_size_.width * draw_frame_w;
            int y = b.y / image_size_.height * draw_frame_h;
            int w = b.w / image_size_.width * draw_frame_w;
            int h = b.h / image_size_.height * draw_frame_h;
            cv::rectangle(draw_frame, cv::Rect(x, y, w, h), cv::Scalar(255, 255, 255, 255), 6, 2, 0);
        }
    }
}

/**
 * @brief Run the whole pipeline once: pre_process, then inference, then post_process.
 *
 * @param input_tensor Input tensor (image frame) passed to pre_process()
 * @param results      Vector passed to post_process() to receive the detections
 */
void FaceDetection::run(runtime_tensor &input_tensor, vector<FaceDetectionInfo> &results)
{
    pre_process(input_tensor);
    inference();
    post_process(results);
}


// ---------------- protected methods --------------------

/**
 * @brief Perform Non-Maximum Suppression (NMS) and generate final face detection results.
 *
 * Walks so_ (expected to be sorted by descending confidence), keeps every entry whose
 * confidence reaches obj_thresh_, and suppresses later entries whose IoU with a kept
 * box reaches nms_thresh_ by zeroing their confidence. The kept boxes and landmarks
 * are then scaled by max(width, height) of the source frame, and the boxes are
 * converted from centre/size to top-left/size and clipped to the frame.
 *
 * Detections are appended to @p results. Entries already present on entry are left
 * untouched; only the newly appended ones are rescaled.
 *
 * @param frame_size   Source frame size; width and height are used.
 * @param results      Output vector that receives the valid detected faces and landmarks.
 */
void FaceDetection::get_final_box(FrameCHWSize &frame_size, vector<FaceDetectionInfo> &results)
{
    // Remember where this call's detections start so earlier entries are not rescaled again.
    const auto first_new = results.size();

    int i, j;
    for (i = 0; i < objs_num_; ++i)
    {
        // Covers both low-score entries and entries suppressed (zeroed) by an earlier box
        if (so_[i].confidence < obj_thresh_)
            continue;

        FaceDetectionInfo obj;
        obj.bbox = get_box(boxes_, so_[i].index);
        obj.sparse_kps = get_landmark(landmarks_, so_[i].index);
        obj.score = so_[i].confidence;

        // Suppress every remaining candidate that overlaps this box too much
        for (j = i + 1; j < objs_num_; ++j)
        {
            if (so_[j].confidence < obj_thresh_)
                continue;
            const Bbox candidate = get_box(boxes_, so_[j].index);
            if (box_iou(obj.bbox, candidate) >= nms_thresh_)
                so_[j].confidence = 0;
        }
        results.push_back(obj);
    }

    // Decoded coordinates are normalised; multiplying by the longer source side maps
    // them back to source pixels (padding is applied at the bottom/right only).
    int max_src_size = std::max(frame_size.width, frame_size.height);
    for (auto idx = first_new; idx < results.size(); ++idx)
    {
        // Scale landmarks
        auto &l = results[idx].sparse_kps;
        for (uint32_t ll = 0; ll < 5; ll++)
        {
            l.points[2 * ll + 0] *= max_src_size;
            l.points[2 * ll + 1] *= max_src_size;
        }

        // Scale bounding boxes (centre/size -> corner coordinates)
        auto &b = results[idx].bbox;
        float x1 = (b.x + b.w / 2) * max_src_size;
        float x0 = (b.x - b.w / 2) * max_src_size;
        float y0 = (b.y - b.h / 2) * max_src_size;
        float y1 = (b.y + b.h / 2) * max_src_size;

        // Clip coordinates within image boundary
        x1 = std::max(0.0f, std::min(x1, (float)frame_size.width));
        x0 = std::max(0.0f, std::min(x0, (float)frame_size.width));
        y0 = std::max(0.0f, std::min(y0, (float)frame_size.height));
        y1 = std::max(0.0f, std::min(y1, (float)frame_size.height));

        // Store as top-left corner plus size
        b.x = x0;
        b.y = y0;
        b.w = x1 - x0;
        b.h = y1 - y0;
    }
}

/**
 * @brief Compute softmax over a given float array.
 *
 * The maximum is subtracted before exponentiation for numerical stability. The input
 * array is only read; @p x and @p dx may point to the same array for an in-place
 * computation. Nothing is done when @p len is 0.
 *
 * @param x     Input array of logits (not modified unless it aliases @p dx).
 * @param dx    Output array to store normalized softmax probabilities.
 * @param len   Length of the input array.
 */
void FaceDetection::local_softmax(float *x, float *dx, uint32_t len)
{
    // Guard: x[0] must not be read for an empty array
    if (len == 0)
        return;

    float max_value = x[0];
    for (uint32_t i = 1; i < len; i++)
        if (max_value < x[i])
            max_value = x[i];

    // Exponentiate into the output and accumulate the normaliser in one pass.
    // dx[i] is written only after x[i] was read, so exact aliasing is safe.
    float sum_value = 0.0f;
    for (uint32_t i = 0; i < len; i++)
    {
        dx[i] = expf(x[i] - max_value);
        sum_value += dx[i];
    }

    // Normalize to get probabilities
    for (uint32_t i = 0; i < len; i++)
        dx[i] /= sum_value;
}

/**
 * @brief Decode confidence scores from model output.
 *
 * This function applies softmax to each confidence value and stores them
 * in the NMSRoiObj array.
 *
 * @param conf      Input confidence tensor.
 * @param so        Output NMSRoiObj array.
 * @param size      Number of confidence groups.
 * @param obj_cnt   Reference counter tracking object index offset.
 */
void FaceDetection::deal_conf(float *conf, NMSRoiObj *so, int size, int &obj_cnt)
{
    float confidence[CONF_SIZE] = {0.0};
    for (uint32_t ww = 0; ww < size; ww++)
    {
        for (uint32_t hh = 0; hh < 2; hh++)
        {
            for (uint32_t cc = 0; cc < CONF_SIZE; cc++)
                confidence[cc] = conf[(hh * CONF_SIZE + cc) * size + ww];

            local_softmax(confidence, confidence, 2);
            so_[obj_cnt].index = obj_cnt;
            so_[obj_cnt].confidence = confidence[1];
            obj_cnt += 1;
        }
    }
}

/**
 * @brief Unpack raw bounding box regressions from one model output branch.
 *
 * Rearranges the channel-major input into one contiguous group of LOC_SIZE values per
 * entry, written to @p boxes starting at entry @p obj_cnt. The values are copied
 * as-is; anchor decoding happens later in get_box().
 *
 * @param loc       Input location tensor, read as [2 * LOC_SIZE][size].
 * @param boxes     Output array that receives LOC_SIZE floats per entry.
 * @param size      Number of positions in this branch.
 * @param obj_cnt   Running entry index; advanced by 2 * size.
 */
void FaceDetection::deal_loc(float *loc, float *boxes, int size, int &obj_cnt)
{
    // Signed loop indices: a non-positive size yields no iterations
    for (int ww = 0; ww < size; ww++)
    {
        for (int hh = 0; hh < 2; hh++)
        {
            // Write through the caller-supplied array rather than the member directly
            for (int cc = 0; cc < LOC_SIZE; cc++)
                boxes[obj_cnt * LOC_SIZE + cc] = loc[(hh * LOC_SIZE + cc) * size + ww];
            obj_cnt += 1;
        }
    }
}

/**
 * @brief Unpack raw facial landmark regressions from one model output branch.
 *
 * Rearranges the channel-major input into one contiguous group of LAND_SIZE values
 * per entry, written to @p landmarks starting at entry @p obj_cnt. The values are
 * copied as-is; anchor decoding happens later in get_landmark().
 *
 * @param landms    Input landmarks tensor, read as [2 * LAND_SIZE][size].
 * @param landmarks Output array that receives LAND_SIZE floats per entry.
 * @param size      Number of positions in this branch.
 * @param obj_cnt   Running entry index; advanced by 2 * size.
 */
void FaceDetection::deal_landms(float *landms, float *landmarks, int size, int &obj_cnt)
{
    // Convert CHW layout to HWC.
    // Signed loop indices: a non-positive size yields no iterations.
    for (int ww = 0; ww < size; ww++)
    {
        for (int hh = 0; hh < 2; hh++)
        {
            // Write through the caller-supplied array rather than the member directly
            for (int cc = 0; cc < LAND_SIZE; cc++)
                landmarks[obj_cnt * LAND_SIZE + cc] = landms[(hh * LAND_SIZE + cc) * size + ww];
            obj_cnt += 1;
        }
    }
}

/**
 * @brief Decode a single bounding box using anchor information.
 *
 * Reads the LOC_SIZE-strided raw prediction for @p obj_index from @p boxes and
 * combines it with row @p obj_index of the active anchor table: the centre is the
 * anchor centre plus 0.1 * offset * anchor size, and the size is the anchor size times
 * exp(0.2 * raw size).
 *
 * @param boxes     Pointer to raw box predictions (LOC_SIZE floats per entry).
 * @param obj_index Index of the target entry / anchor row.
 * @return Bbox     Decoded box as centre (x, y) and size (w, h).
 */
Bbox FaceDetection::get_box(float *boxes, int obj_index)
{
    // Read from the caller-supplied array rather than the member directly
    const float *raw = boxes + obj_index * LOC_SIZE;
    const float *anchor = g_anchors[obj_index];

    float cx = raw[0];
    float cy = raw[1];
    float w  = raw[2];
    float h  = raw[3];

    // Decode using anchors
    cx = anchor[0] + cx * 0.1f * anchor[2];
    cy = anchor[1] + cy * 0.1f * anchor[3];
    w  = anchor[2] * k230_expf(w * 0.2f);
    h  = anchor[3] * k230_expf(h * 0.2f);

    Bbox box;
    box.x = cx;
    box.y = cy;
    box.w = w;
    box.h = h;
    return box;
}

/**
 * @brief Decode a single face's 5-point landmarks using anchor information.
 *
 * Reads the LAND_SIZE-strided raw prediction for @p obj_index from @p landmarks. Each
 * point is the anchor centre plus 0.1 * offset * anchor size, using row @p obj_index
 * of the active anchor table.
 *
 * @param landmarks  Pointer to raw landmark predictions (LAND_SIZE floats per entry).
 * @param obj_index  Index of the target entry / anchor row.
 * @return SparseLandmarks  Structure containing 5 decoded (x, y) points.
 */
SparseLandmarks FaceDetection::get_landmark(float *landmarks, int obj_index)
{
    // Read from the caller-supplied array rather than the member directly
    const float *raw = landmarks + obj_index * LAND_SIZE;
    const float *anchor = g_anchors[obj_index];

    SparseLandmarks landmark;
    for (uint32_t ll = 0; ll < 5; ll++)
    {
        landmark.points[2 * ll + 0] = anchor[0] + raw[2 * ll + 0] * 0.1f * anchor[2];
        landmark.points[2 * ll + 1] = anchor[1] + raw[2 * ll + 1] * 0.1f * anchor[3];
    }
    return landmark;
}

/**
 * @brief Signed 1D overlap of two centre/extent intervals.
 *
 * @param x1  Center of the first interval.
 * @param w1  Extent of the first interval.
 * @param x2  Center of the second interval.
 * @param w2  Extent of the second interval.
 * @return Overlap length; negative when the intervals do not intersect.
 */
float FaceDetection::overlap(float x1, float w1, float x2, float w2)
{
    const float half1 = w1 / 2;
    const float half2 = w2 / 2;

    // The overlap runs from the larger start to the smaller end
    const float start = std::max(x1 - half1, x2 - half2);
    const float end = std::min(x1 + half1, x2 + half2);
    return end - start;
}

/**
 * @brief Calculate the intersection area between two centre/size bounding boxes.
 *
 * @param a  First bounding box.
 * @param b  Second bounding box.
 * @return Intersection area, or 0 when the boxes do not overlap on either axis.
 */
float FaceDetection::box_intersection(Bbox a, Bbox b)
{
    const float span_x = overlap(a.x, a.w, b.x, b.w);
    const float span_y = overlap(a.y, a.h, b.y, b.h);

    // A negative span on either axis means the boxes are disjoint
    return (span_x < 0 || span_y < 0) ? 0 : span_x * span_y;
}

/**
 * @brief Calculate the union area of two bounding boxes.
 *
 * @param a  First bounding box.
 * @param b  Second bounding box.
 * @return Sum of both box areas minus their intersection area.
 */
float FaceDetection::box_union(Bbox a, Bbox b)
{
    const float shared = box_intersection(a, b);
    return a.w * a.h + b.w * b.h - shared;
}

/**
 * @brief Compute Intersection-over-Union (IoU) between two boxes.
 *
 * The intersection is computed once and reused for the union. When the union area is
 * not positive (e.g. both boxes are degenerate), 0 is returned instead of dividing by
 * zero.
 *
 * @param a  First bounding box.
 * @param b  Second bounding box.
 * @return IoU ratio in range [0, 1] for boxes with non-negative sizes.
 */
float FaceDetection::box_iou(Bbox a, Bbox b)
{
    const float inter = box_intersection(a, b);
    // Same formula as box_union(), inlined to avoid a second intersection computation
    const float uni = a.w * a.h + b.w * b.h - inter;

    // Written as !(uni > 0) so that a NaN union also takes the safe path
    if (!(uni > 0.0f))
        return 0.0f;
    return inter / uni;
}
