#include "kws.h"

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

/**
 * @brief Sets up the feature pipeline, the KPU model wrapper and the model input handles.
 *
 * @param kmodel_file Forwarded to the KPU constructor (presumably the path of the
 *                    kmodel file to load).
 * @param thresh      Stored as spot_thresh; post_process() compares the best keyword
 *                    score against it.
 * @param debug_mode  Stored as debug_mode_ and handed to PROFILE_SCOPE_AUTO in the
 *                    processing stages.
 */
KWS::KWS(std::string kmodel_file, float thresh, int debug_mode)
{
    // Scalar settings first; nothing below depends on the order of these.
    debug_mode_ = debug_mode;
    spot_thresh = thresh;
    num_keyword = 2;

    // The cache starts out as hidden_dim rows of cache_dim zeros.
    const std::vector<float> zero_row(cache_dim, 0.0f);
    cache.assign(hidden_dim, zero_row);

    // The pipeline is built from the config object, so the config must exist first.
    // NOTE(review): (40, 16000) presumably means 40 feature bins at 16 kHz — confirm
    // against wenet::FeaturePipelineConfig.
    feature_config   = new wenet::FeaturePipelineConfig(40, 16000);
    feature_pipeline = new wenet::FeaturePipeline(*feature_config);

    // Model input 0 receives the features, input 1 receives the cache (see pre_process()).
    kpu = new KPU(kmodel_file);
    model_input_tensor_wav   = kpu->get_input_tensor(0);
    model_input_tensor_cache = kpu->get_input_tensor(1);
}

/**
 * @brief Releases every object the constructor allocated with `new`.
 *
 * `delete` on a null pointer is a no-op, so no null checks are needed.
 */
KWS::~KWS()
{
    delete kpu;
    kpu = nullptr;

    // The pipeline was constructed from *feature_config and may keep referring to it,
    // so it is destroyed before the config.
    delete feature_pipeline;
    feature_pipeline = nullptr;

    // Previously never freed: the config object leaked on every KWS destruction.
    delete feature_config;
    feature_config = nullptr;
}

namespace {

/**
 * @brief Writes the rows of @p rows back-to-back into @p dst.
 *
 * Exactly @p capacity floats are written: row data first (truncated if the rows hold
 * more than @p capacity values), followed by zero padding if they hold fewer. The
 * source rows are never read past their own size.
 *
 * @param rows     Source rows, concatenated in order.
 * @param dst      Destination with room for at least @p capacity floats.
 * @param capacity Number of floats to write into @p dst.
 */
void kws_flatten_rows_into(const std::vector<std::vector<float>>& rows,
                           float* dst,
                           std::size_t capacity)
{
    std::size_t written = 0;
    for (const auto& row : rows) {
        if (written == capacity) {
            break;
        }
        const std::size_t count = std::min(row.size(), capacity - written);
        std::copy(row.begin(), row.begin() + count, dst + written);
        written += count;
    }
    std::fill(dst + written, dst + capacity, 0.0f);
}

} // namespace

/**
 * @brief Feeds audio samples to the feature pipeline and fills both model inputs.
 *
 * The samples are appended to the feature pipeline, then chunk_size feature frames
 * are read back. If the read reports failure the function returns without touching
 * the model inputs. Otherwise the frames are written, flattened row by row, into
 * model input 0 and the current cache is written the same way into model input 1.
 *
 * @param wav Audio samples handed to the feature pipeline.
 */
void KWS::pre_process(std::vector<float>& wav)
{
    PROFILE_SCOPE_AUTO(debug_mode_);
    feature_pipeline->AcceptWaveform(wav);

    std::vector<std::vector<float>> feats;
    if (!feature_pipeline->Read(chunk_size, &feats)) {
        return;
    }

    // Element counts written to each input tensor.
    const size_t wav_feature_size = chunk_size * num_bin;
    const size_t cache_size = hidden_dim * cache_dim;

    // Rows are copied straight into the mapped tensor memory. This avoids a temporary
    // flattened vector per input and, unlike a fixed-size memcpy from such a vector,
    // cannot read past the data that is actually present.
    // NOTE(review): no explicit write-back/sync is issued after these writes — confirm
    // that the runtime or KPU::run() takes care of it.
    auto host_buf_wav = model_input_tensor_wav.impl()
                            ->to_host().unwrap()
                            ->buffer().as_host().unwrap()
                            .map(map_access_::map_write).unwrap();
    kws_flatten_rows_into(feats,
                          reinterpret_cast<float*>(host_buf_wav.buffer().data()),
                          wav_feature_size);

    auto host_buf_cache = model_input_tensor_cache.impl()
                              ->to_host().unwrap()
                              ->buffer().as_host().unwrap()
                              .map(map_access_::map_write).unwrap();
    kws_flatten_rows_into(cache,
                          reinterpret_cast<float*>(host_buf_cache.buffer().data()),
                          cache_size);
}

/**
 * @brief Runs the loaded model once on the inputs prepared by pre_process().
 *
 * The call is wrapped in a PROFILE_SCOPE_AUTO scope controlled by debug_mode_.
 */
void KWS::inference()
{
    PROFILE_SCOPE_AUTO(debug_mode_);

    // All model execution is delegated to the KPU wrapper.
    kpu->run();
}

/**
 * @brief Consumes the outputs of the last model run: refreshes the cache and decides
 *        whether a keyword was spotted.
 *
 * Model output 1 is copied row by row into `cache` so the next pre_process() feeds it
 * back to the model. Model output 0 is read as a row-major score matrix with 30 rows
 * and num_keyword columns; the maximum of each column is taken and the column with
 * the largest maximum wins (the lowest index wins ties).
 *
 * A keyword counts as spotted only when the winning score exceeds spot_thresh and
 * the winning index is not 0 (index 0 is reported as "Deactivated!"). The decision is
 * also printed to stdout.
 *
 * @return The winning keyword index when a keyword was spotted, otherwise 0. The
 *         return value therefore always agrees with the printed message; a winning
 *         index whose score is at or below the threshold is no longer returned.
 */
int KWS::post_process()
{
    PROFILE_SCOPE_AUTO(debug_mode_);
    const float* scores     = reinterpret_cast<float*>(kpu->get_output_ptr(0));
    const float* next_cache = reinterpret_cast<float*>(kpu->get_output_ptr(1));

    // 1. Copy the new cache back, advancing by each row's own length.
    for (auto& row : cache) {
        std::copy(next_cache, next_cache + row.size(), row.begin());
        next_cache += row.size();
    }

    // 2. Per-keyword maximum over all rows, folded directly into a global max/argmax.
    //    Working column by column needs no fixed-size scratch array, so any
    //    num_keyword value stays in bounds.
    // NOTE(review): 30 is the row count hard-coded by the original code — presumably
    // the number of output frames of the model; confirm against the model.
    constexpr int kScoreRows = 30;
    const int keyword_count = num_keyword;

    int best_index = 0;
    float best_score = std::numeric_limits<float>::lowest();
    for (int k = 0; k < keyword_count; ++k) {
        float column_max = std::numeric_limits<float>::lowest();
        for (int r = 0; r < kScoreRows; ++r) {
            column_max = std::max(column_max, scores[r * keyword_count + k]);
        }
        // Strict '>' keeps the lowest index on ties.
        if (k == 0 || column_max > best_score) {
            best_score = column_max;
            best_index = k;
        }
    }

    // 3. Wake-up decision: above threshold and not the index-0 class.
    const bool spotted = (best_score > spot_thresh) && (best_index != 0);
    if (spotted) {
        std::cout << "XiaonanXiaonan!" << std::endl;
    } else {
        std::cout << "Deactivated!" << std::endl;
    }
    return spotted ? best_index : 0;
}





