#include "face_mask.h"

// ------------------public methods------------------
// Construct a face-mask classifier around a kmodel file.
//
// kmodel_file : path handed to the KPU constructor.
// mask_thresh : score threshold stored for the mask / no-mask decision.
// image_size  : size of the source frame, stored for preprocessing and drawing.
// debug_mode  : value later passed to PROFILE_SCOPE_AUTO in the processing stages.
FaceMask::FaceMask(std::string &kmodel_file,float mask_thresh, FrameCHWSize image_size,int debug_mode)
{
    // Keep the caller-supplied configuration.
    debug_mode_ = debug_mode;
    mask_thresh_ = mask_thresh;
    image_size_ = image_size;

    // Create the inference engine and the AI2D preprocessor.
    kpu = new KPU(kmodel_file);
    ai2d = new AI2D();

    // Model input 0 is laid out as [N, C, H, W]; it is stored here as {H, W, C}.
    // NOTE(review): assumes the brace-init order of input_size_'s type is
    // {H, W, C} -- confirm against its declaration.
    dims_t model_shape = kpu->get_input_shape(0);
    input_size_ = {model_shape[2], model_shape[3], model_shape[1]};

    // Hold on to the model's input tensor; pre_process() uses it as the AI2D output.
    model_input_tensor_ = kpu->get_input_tensor(0);
}

// Release the AI2D preprocessor and the KPU instance created by the constructor.
FaceMask::~FaceMask()
{
    // delete on a null pointer is a no-op, so explicit null checks are not needed.
    delete ai2d;
    ai2d = nullptr;

    delete kpu;
    kpu = nullptr;
}

// Warp the input frame into the model input tensor with an affine transform.
//
// input_tensor  : source frame; described to AI2D as {1, 3, image height, image width}.
// sparse_points : ten floats (five coordinate pairs) used to derive the affine matrix.
//
// The AI2D pipeline is reconfigured and rebuilt on every call because the
// affine coefficients in matrix_dst_ are recomputed from sparse_points each time.
void FaceMask::pre_process(runtime_tensor& input_tensor, float* sparse_points){
    PROFILE_SCOPE_AUTO(debug_mode_);

    // Refresh matrix_dst_ from the supplied points.
    get_affine_matrix(sparse_points);

    // uint8 NCHW in, uint8 NCHW out.
    ai2d->set_ai2d_dtype(ai2d_format::NCHW_FMT, ai2d_format::NCHW_FMT,
                         typecode_t::dt_uint8, typecode_t::dt_uint8);

    // Only the first six coefficients (the 2x3 part) of matrix_dst_ are passed on.
    ai2d->set_affine(ai2d_interp_method::cv2_bilinear, 0, 0, 127, 1,
                     {matrix_dst_[0], matrix_dst_[1], matrix_dst_[2],
                      matrix_dst_[3], matrix_dst_[4], matrix_dst_[5]});

    // Source frame dimensions and model input dimensions.
    const long unsigned int frame_h = static_cast<long unsigned int>(image_size_.height);
    const long unsigned int frame_w = static_cast<long unsigned int>(image_size_.width);
    const long unsigned int model_h = static_cast<long unsigned int>(input_size_.height);
    const long unsigned int model_w = static_cast<long unsigned int>(input_size_.width);

    dims_t input_shape{1, 3, frame_h, frame_w};
    dims_t output_shape{1, 3, model_h, model_w};

    // Build the pipeline and write the warped result into the cached model input tensor.
    ai2d->build(input_shape, output_shape);
    ai2d->run(input_tensor, model_input_tensor_);
}

void FaceMask::inference()
{
	PROFILE_SCOPE_AUTO(debug_mode_);
    kpu->run();
}

// Turn model output 0 into a score and a label.
//
// result : receives score = softmax(output)[1] and label "no mask" when that
//          score is below mask_thresh_, otherwise "mask".
void FaceMask::post_process(FaceMaskInfo& result)
{
    PROFILE_SCOPE_AUTO(debug_mode_);

    // View output 0 as floats and copy it into a vector for softmax().
    float* logits_begin = reinterpret_cast<float*>(kpu->get_output_ptr(0));
    float* logits_end = logits_begin + kpu->get_output_data_size(0);
    vector<float> logits(logits_begin, logits_end);

    vector<float> probabilities;
    softmax(logits, probabilities);

    // Index 1 is the value compared against the mask threshold.
    result.score = probabilities[1];
    result.label = (result.score < mask_thresh_) ? "no mask" : "mask";
}

// Draw the face box and a "<label>:<value>" caption onto draw_frame.
//
// draw_frame : image to draw on.
// bbox       : face box. For a 3-channel frame it is used as-is; otherwise it
//              is rescaled from image_size_ to the frame's own size.
// result     : label and score to display.
//
// The caption shows (1 - score) for the "no mask" case and score otherwise.
// The "no mask" case is detected by score < mask_thresh_ for 3-channel frames
// and by label == "no mask" for all other frames.
//
// The caption is formatted with snprintf so that a long caller-supplied label
// is truncated to the 30-byte buffer instead of overflowing it.
void FaceMask::draw_results(cv::Mat& draw_frame,Bbox& bbox,FaceMaskInfo& result)
{
    const int src_w = draw_frame.cols;
    const int src_h = draw_frame.rows;
    const int src_c = draw_frame.channels();

    char text[30];

    if(src_c == 3)
    {
        cv::rectangle(draw_frame, cv::Rect(bbox.x, bbox.y , bbox.w, bbox.h), cv::Scalar(255, 255, 255), 2, 2, 0);

        const bool no_mask = result.score < mask_thresh_;
        const cv::Scalar text_color = no_mask ? cv::Scalar(255, 0, 0) : cv::Scalar(0, 0, 255);

        snprintf(text, sizeof(text), "%s:%.2f", result.label.c_str(),
                 no_mask ? 1-result.score : result.score);
        // Keep the caption inside the image when the box touches the top edge.
        cv::putText(draw_frame, text , {bbox.x,std::max(int(bbox.y-10),0)}, cv::FONT_HERSHEY_COMPLEX, 0.6, text_color, 1, 8, 0);
    }
    else
    {
        // Map the box from image_size_ coordinates to this frame's coordinates.
        int x = bbox.x / image_size_.width * src_w;
        int y = bbox.y / image_size_.height * src_h;
        int w = bbox.w / image_size_.width * src_w;
        int h = bbox.h / image_size_.height  * src_h;
        cv::rectangle(draw_frame, cv::Rect(x, y , w, h), cv::Scalar(255,255, 255, 255), 2, 2, 0);

        const bool no_mask = (result.label == "no mask");
        const cv::Scalar text_color = no_mask ? cv::Scalar(0,0, 255, 255) : cv::Scalar(0,255, 0, 255);

        snprintf(text, sizeof(text), "%s:%.2f", result.label.c_str(),
                 no_mask ? 1-result.score : result.score);
        cv::putText(draw_frame,text,cv::Point(x,std::max(int(y-10),0)),cv::FONT_HERSHEY_COMPLEX,2,text_color, 2, 8, 0);
    }
}

// Full pipeline for one face: preprocess, infer, then fill `result`.
//
// input_tensor  : source frame tensor handed to pre_process().
// sparse_points : ten floats (five coordinate pairs) handed to pre_process().
// result        : receives the score and label from post_process().
void FaceMask::run(runtime_tensor& input_tensor,float* sparse_points, FaceMaskInfo& result)
{
    pre_process(input_tensor, sparse_points);
    inference();
    post_process(result);
}

// ------------------protected methods------------------

// Closed-form decomposition of a 2x2 matrix.
//
// a : input matrix, four floats in row-major order.
// u : output, four floats (row-major 2x2).
// s : output, two non-negative floats with s[0] >= s[1].
// v : output, four floats (row-major 2x2).
//
// NOTE(review): the caller multiplies u by v directly, so v is presumably the
// already-transposed right factor -- confirm before reusing elsewhere.
void FaceMask::svd22(const float a[4], float u[4], float s[2], float v[4])
{
    const float a00 = a[0];
    const float a01 = a[1];
    const float a10 = a[2];
    const float a11 = a[3];

    // The two square-root terms; the first one is needed for both s[0] and s[1].
    const float root_diff = sqrtf(powf(a00 - a11, 2) + powf(a01 + a10, 2));
    const float root_sum = sqrtf(powf(a00 + a11, 2) + powf(a01 - a10, 2));

    const float sigma0 = (root_diff + root_sum) / 2;
    const float sigma1 = fabsf(sigma0 - root_diff);
    s[0] = sigma0;
    s[1] = sigma1;

    // Sine of half the rotation angle; zero when both values in s are equal.
    float sin_half = 0;
    if (sigma0 > sigma1)
    {
        const float angle = atan2f(2 * (a00 * a01 + a10 * a11),
                                   a00 * a00 - a01 * a01 + a10 * a10 - a11 * a11);
        sin_half = sinf(angle / 2);
    }
    const float cos_half = sqrtf(1 - sin_half * sin_half);

    // Working entries of v used while deriving u (sign of v0 / v2 is flipped afterwards).
    const float v0 = cos_half;
    const float v1 = -sin_half;
    const float v2 = sin_half;
    const float v3 = cos_half;

    if (sigma0 != 0)
    {
        u[0] = -(a00 * v0 + a01 * v2) / sigma0;
        u[2] = -(a10 * v0 + a11 * v2) / sigma0;
    }
    else
    {
        u[0] = 1;
        u[2] = 0;
    }

    if (sigma1 != 0)
    {
        u[1] = (a00 * v1 + a01 * v3) / sigma1;
        u[3] = (a10 * v1 + a11 * v3) / sigma1;
    }
    else
    {
        // Fall back to a column built from the first one.
        u[1] = -u[2];
        u[3] = u[0];
    }

    // Final v: first-column entries carry the flipped sign.
    v[0] = -v0;
    v[1] = v1;
    v[2] = -v2;
    v[3] = v3;
}

// Destination reference points used by image_umeyama_128(): five coordinate
// pairs given on a 112-unit scale and rescaled to PIC_SIZE.
// NOTE(review): presumably the usual five-point face-alignment template
// (eyes, nose, mouth corners) -- confirm the point order with the landmark source.
#define PIC_SIZE 128
static float umeyama_args_128[] =
{
	38.2946 * PIC_SIZE / 112, 51.6963 * PIC_SIZE / 112,   // pair 0
	73.5318 * PIC_SIZE / 112, 51.5014 * PIC_SIZE / 112,   // pair 1
	56.0252 * PIC_SIZE / 112, 71.7366 * PIC_SIZE / 112,   // pair 2
	41.5493 * PIC_SIZE / 112, 92.3655 * PIC_SIZE / 112,   // pair 3
	70.7299 * PIC_SIZE / 112, 92.2041 * PIC_SIZE / 112    // pair 4
};

// Estimate the similarity transform (rotation, uniform scale, translation)
// that maps five source points onto the reference points in umeyama_args_128.
//
// src : ten floats, five coordinate pairs stored as [p0a, p0b, p1a, p1b, ...].
// dst : output, written as a row-major 3x3 matrix, so it must have room for at
//       least nine floats. Rows 0-1 hold the 2x3 affine part; row 2 is [0, 0, 1].
//
// If all source points coincide, the variance is zero and the scale (and so
// the output) is not finite; callers are expected to pass distinct points.
void FaceMask::image_umeyama_128(float* src, float* dst)
{
#define SRC_NUM 5
#define SRC_DIM 2
    // Centroids of the source points and of the reference points.
    float src_mean[SRC_DIM] = { 0.0 };
    float dst_mean[SRC_DIM] = { 0.0 };
    for (int i = 0; i < SRC_NUM * 2; i += 2)
    {
        src_mean[0] += src[i];
        src_mean[1] += src[i + 1];
        dst_mean[0] += umeyama_args_128[i];
        dst_mean[1] += umeyama_args_128[i + 1];
    }
    src_mean[0] /= SRC_NUM;
    src_mean[1] /= SRC_NUM;
    dst_mean[0] /= SRC_NUM;
    dst_mean[1] /= SRC_NUM;

    // Points with their centroid removed.
    float src_demean[SRC_NUM][2] = { 0.0 };
    float dst_demean[SRC_NUM][2] = { 0.0 };
    for (int i = 0; i < SRC_NUM; i++)
    {
        src_demean[i][0] = src[2 * i] - src_mean[0];
        src_demean[i][1] = src[2 * i + 1] - src_mean[1];
        dst_demean[i][0] = umeyama_args_128[2 * i] - dst_mean[0];
        dst_demean[i][1] = umeyama_args_128[2 * i + 1] - dst_mean[1];
    }

    // 2x2 cross-covariance: A = (dst_demean^T * src_demean) / SRC_NUM.
    float A[SRC_DIM][SRC_DIM] = { 0.0 };
    for (int i = 0; i < SRC_DIM; i++)
    {
        for (int k = 0; k < SRC_DIM; k++)
        {
            for (int j = 0; j < SRC_NUM; j++)
            {
                A[i][k] += dst_demean[j][i] * src_demean[j][k];
            }
            A[i][k] /= SRC_NUM;
        }
    }

    // Decompose A; the rotation part is the product U * V.
    float U[SRC_DIM][SRC_DIM] = { 0 };
    float S[SRC_DIM] = { 0 };
    float V[SRC_DIM][SRC_DIM] = { 0 };
    svd22(&A[0][0], &U[0][0], S, &V[0][0]);

    // View the output buffer as a 3x3 row-major matrix.
    float(*T)[SRC_DIM + 1] = reinterpret_cast<float(*)[SRC_DIM + 1]>(dst);
    T[0][0] = U[0][0] * V[0][0] + U[0][1] * V[1][0];
    T[0][1] = U[0][0] * V[0][1] + U[0][1] * V[1][1];
    T[1][0] = U[1][0] * V[0][0] + U[1][1] * V[1][0];
    T[1][1] = U[1][0] * V[0][1] + U[1][1] * V[1][1];

    // Per-axis variance of the centred source points.
    float src_demean_mean[SRC_DIM] = { 0.0 };
    float src_demean_var[SRC_DIM] = { 0.0 };
    for (int i = 0; i < SRC_NUM; i++)
    {
        src_demean_mean[0] += src_demean[i][0];
        src_demean_mean[1] += src_demean[i][1];
    }
    src_demean_mean[0] /= SRC_NUM;
    src_demean_mean[1] /= SRC_NUM;

    for (int i = 0; i < SRC_NUM; i++)
    {
        src_demean_var[0] += (src_demean_mean[0] - src_demean[i][0]) * (src_demean_mean[0] - src_demean[i][0]);
        src_demean_var[1] += (src_demean_mean[1] - src_demean[i][1]) * (src_demean_mean[1] - src_demean[i][1]);
    }
    src_demean_var[0] /= (SRC_NUM);
    src_demean_var[1] /= (SRC_NUM);

    // Uniform scale: sum of the values in S over the total source variance.
    const float scale = 1.0 / (src_demean_var[0] + src_demean_var[1]) * (S[0] + S[1]);

    // Translation uses the unscaled rotation, so it is computed before scaling it.
    T[0][2] = dst_mean[0] - scale * (T[0][0] * src_mean[0] + T[0][1] * src_mean[1]);
    T[1][2] = dst_mean[1] - scale * (T[1][0] * src_mean[0] + T[1][1] * src_mean[1]);
    T[0][0] *= scale;
    T[0][1] *= scale;
    T[1][0] *= scale;
    T[1][1] *= scale;

    // Bottom row of the homogeneous matrix.
    T[2][0] = 0;
    T[2][1] = 0;
    T[2][2] = 1;
}

void FaceMask::get_affine_matrix(float* sparse_points)
{
    float matrix_src[5][2];
    for (uint32_t i = 0; i < 5; ++i)
    {
        matrix_src[i][0] = sparse_points[2 * i + 0];
		matrix_src[i][1] = sparse_points[2 * i + 1];
    }
    image_umeyama_128(&matrix_src[0][0], &matrix_dst_[0]);
}

// Numerically stable softmax: output[i] = exp(input[i] - max) / sum_j exp(input[j] - max).
//
// input  : values to normalise; not modified.
// output : resized to input.size() and filled with the result. An empty input
//          yields an empty output (the maximum of an empty range cannot be
//          dereferenced, so that case returns early).
void FaceMask::softmax(vector<float>& input,vector<float>& output)
{
    if (input.empty())
    {
        output.clear();
        return;
    }

    // Subtracting the maximum keeps every exponent <= 0 (e_x = exp(x - max(x))).
    const float input_max = *std::max_element(input.begin(), input.end());

    float input_total = 0;
    for (auto x : input)
    {
        input_total += exp(x - input_max);
    }

    output.resize(input.size());
    for (std::size_t i = 0; i < input.size(); ++i)
    {
        output[i] = exp(input[i] - input_max) / input_total;
    }
}