#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "ai_utils.h"

using std::ofstream;
using std::vector;

/**
 * @brief Get color list for classes by cycling through predefined color palette.
 * 
 * @param num_classes Number of object classes.
 * @return std::vector<cv::Scalar> Color list corresponding to each class.
 */
std::vector<cv::Scalar> get_colors_for_classes(int num_classes) {
    std::vector<cv::Scalar> colors;
    int num_available_colors = color_four.size(); 
    for (int i = 0; i < num_classes; ++i) {
        colors.push_back(color_four[i % num_available_colors]);
    }
    return colors;
}

/**
 * @brief Convert an 8-bit cv::Mat image to a dt_uint8 runtime_tensor in NCHW or NHWC layout.
 *
 * "CHW" mode splits the image into planes and packs them into a {1, C, H, W}
 * tensor. The plane order is reversed while packing: the last input channel
 * comes first, e.g. BGR -> RGB for a 3-channel image. "HWC" mode copies the
 * interleaved pixel data unchanged into a {1, H, W, C} tensor, with no channel
 * reordering.
 *
 * @param ori_img Input image; must be non-empty with 1-byte elements (e.g. CV_8U).
 *                Non-continuous Mats (ROI views) are accepted.
 * @param mode Conversion mode: "CHW" for channel-first (NCHW) layout,
 *             "HWC" for channel-last (NHWC) layout.
 * @return runtime_tensor Tensor holding the image data.
 * @throws std::runtime_error if the image is empty or not 8-bit, if mode is
 *         invalid, or if any conversion step throws. The message is prefixed
 *         with "Error: Failed to convert image to tensor: ".
 */
runtime_tensor mat_to_tensor(cv::Mat &ori_img, std::string mode) {
    runtime_tensor image_tensor;
    try {
        if (ori_img.empty()) {
            throw std::runtime_error("Error: Input image is empty.");
        }
        // Both packing paths copy raw bytes into a dt_uint8 tensor, so each
        // element must be exactly one byte; wider depths would be garbled.
        if (ori_img.elemSize1() != 1) {
            throw std::runtime_error("Error: Input image must have 8-bit depth.");
        }

        const size_t H = ori_img.rows;
        const size_t W = ori_img.cols;
        const size_t C = ori_img.channels();

        if (mode == "CHW") {
            std::vector<uint8_t> chw_vec(C * H * W);
            uint8_t *ptr = chw_vec.data();
            std::vector<cv::Mat> channels(C);
            cv::split(ori_img, channels);
            for (size_t c = 0; c < C; ++c) {
                // Planes are written in reverse order (last input channel first).
                const cv::Mat &plane = channels[C - 1 - c];
                for (size_t i = 0; i < H; ++i) {
                    memcpy(ptr, plane.ptr<uint8_t>(static_cast<int>(i)), W);
                    ptr += W;
                }
            }

            dims_t image_shape {1, C, H, W};
            // chw_vec is a local buffer; the `true` argument is presumably
            // nncase's copy flag, so the tensor owns its own copy of the data.
            image_tensor = host_runtime_tensor::create(
                typecode_t::dt_uint8,
                image_shape,
                { reinterpret_cast<gsl::byte*>(chw_vec.data()), C * H * W },
                true,
                hrt::pool_shared
            ).expect("cannot create input tensor");
        } else if (mode == "HWC") {
            // The tensor is filled from a single contiguous H*W*C byte span, so
            // a non-continuous Mat (e.g. an ROI view) must be compacted first.
            cv::Mat src = ori_img.isContinuous() ? ori_img : ori_img.clone();
            dims_t image_shape {1, H, W, C};
            image_tensor = host_runtime_tensor::create(
                typecode_t::dt_uint8,
                image_shape,
                { reinterpret_cast<gsl::byte*>(src.data), H * W * C },
                true,
                hrt::pool_shared
            ).expect("cannot create input tensor");
        }
        else {
            throw std::runtime_error("Error: Invalid mode. Use 'CHW' or 'HWC'.");
        }
        hrt::sync(image_tensor, sync_op_t::sync_write_back, true).expect("write back input failed");
    } catch (const std::exception &e) {
        throw std::runtime_error(std::string("Error: Failed to convert image to tensor: ") + e.what());
    }
    return image_tensor;
}
