# KPU (Knowledge Processing Unit) 接口开发文档

## KPU 模块概述

`KPU` 类是对 `nncase` kmodel 推理接口的封装，提供了对模型输入输出张量的管理、推理执行以及输出数据访问功能。
该类封装了模型加载、输入输出张量初始化、推理执行以及输出数据指针获取等常用操作，简化了模型推理流程。

**主要特性：**

* 支持加载 kmodel 模型文件
* 自动初始化输入输出张量及其形状
* 提供设置单个或多个输入张量的接口
* 执行推理并获取输出张量对象或原始数据指针
* 禁止拷贝构造与赋值，支持移动构造与移动赋值

---

## KPU 类定义

```cpp
class KPU
{
public:
    explicit KPU(const std::string& kmodel_file);
    ~KPU();

    KPU(const KPU&) = default;
    KPU& operator=(const KPU&) = default;

    KPU(KPU&&) noexcept = default;
    KPU& operator=(KPU&&) noexcept = default;

    int get_input_size() noexcept;
    int get_output_size() noexcept;
    dims_t get_input_shape(int idx);
    dims_t get_output_shape(int idx);
    datatype_t get_input_typecode(int idx);
    datatype_t get_output_typecode(int idx);
    size_t get_input_data_size(int idx);
    size_t get_output_data_size(int idx);
    size_t get_input_data_bytes(int idx);
    size_t get_output_data_bytes(int idx);
    void set_input_tensor(int idx, const nncase::runtime::runtime_tensor& input_tensor);
    void set_input_tensors(const std::vector<nncase::runtime::runtime_tensor>& input_tensors);
    nncase::runtime::runtime_tensor get_input_tensor(int idx);
    void run();
    nncase::runtime::runtime_tensor get_output_tensor(int idx);
    char* get_output_ptr(int idx);
    std::vector<char> get_output_data(int idx);

private:
    std::string kmodel_file_;
    nncase::runtime::interpreter kmodel_interp_;
    int input_size_{0};
    int output_size_{0};
    std::vector<dims_t> input_shapes_;
    std::vector<dims_t> output_shapes_;
    std::vector<typecode_t> input_typecodes_;
    std::vector<typecode_t> output_typecodes_;
    std::vector<size_t> input_data_size_;
    std::vector<size_t> output_data_size_;
    std::vector<size_t> input_data_bytes_;
    std::vector<size_t> output_data_bytes_;
    std::vector<std::vector<char>> output_data_;
};
```

---

## KPU类接口说明

### KPU类构造函数

```cpp
explicit KPU(const std::string& kmodel_file);
```

**功能：**
加载指定 kmodel 文件，并初始化输入输出张量及其形状。

**参数：**

| 参数名          | 类型            | 描述          |
| ------------ | ------------- | ----------- |
| kmodel\_file | `std::string` | kmodel 文件路径 |

**异常：**

* 文件打开失败：`std::runtime_error`
* 模型加载失败：`std::runtime_error`
* 文件关闭失败：`std::runtime_error`

---

### KPU类析构函数

```cpp
~KPU();
```

**功能：**
销毁 KPU 对象，释放内部资源。

---

### 获取输入数量

```cpp
int get_input_size() noexcept;
```

**功能：**
返回模型的输入张量数量。

**返回值：**

* `int`：输入张量数量

---

### 获取输出数量

```cpp
int get_output_size() noexcept;
```

**功能：**
返回模型的输出张量数量。

**返回值：**

* `int`：输出张量数量

---

### 获取指定输入张量形状

```cpp
dims_t get_input_shape(int idx);
```

**功能：**
返回指定索引输入张量的形状。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输入张量索引 |

**返回值：**

* `dims_t`：输入张量维度

**异常：**

* 索引越界时：`std::out_of_range`

---

### 获取指定输出张量形状

```cpp
dims_t get_output_shape(int idx);
```

**功能：**
返回指定索引输出张量的形状。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `dims_t`：输出张量维度

**异常：**

* 索引越界时：`std::out_of_range`

---

### 获取指定输入张量数据类型

```cpp
datatype_t get_input_typecode(int idx);
```

**功能：**
返回指定索引输入张量的数据类型。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输入张量索引 |

**返回值：**

* `datatype_t`：输入张量数据类型

**异常：**

* 索引越界时：`std::out_of_range`

**补充：**

| 返回值(强转int) | 数据类型 | 对应 C++ 类型                   | 占用字节数 |
|--------|-------------|---------------------------------|------------|
| 0      | `dt_boolean`   | `bool`                          | 1          |
| 1      | `dt_utf8char`  | `char` / `u8char`               | 1          |
| 2      | `dt_int8`      | `int8_t`                        | 1          |
| 3      | `dt_int16`     | `int16_t`                       | 2          |
| 4      | `dt_int32`     | `int32_t`                       | 4          |
| 5      | `dt_int64`     | `int64_t`                       | 8          |
| 6      | `dt_uint8`     | `uint8_t`                       | 1          |
| 7      | `dt_uint16`    | `uint16_t`                      | 2          |
| 8      | `dt_uint32`    | `uint32_t`                      | 4          |
| 9      | `dt_uint64`    | `uint64_t`                      | 8          |
| 10     | `dt_float16`   | 半精度浮点 (`_Float16`)          | 2          |
| 11     | `dt_float32`   | `float`                         | 4          |
| 12     | `dt_float64`   | `double`                        | 8          |
| 13     | `dt_bfloat16`  | `bfloat16`                      | 2          |

---

### 获取指定输出张量数据类型

```cpp
datatype_t get_output_typecode(int idx);
```

**功能：**
返回指定索引输出张量的数据类型。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `datatype_t`：输出张量数据类型

**异常：**

* 索引越界时：`std::out_of_range`

**补充：**

| 返回值(强转int) | 数据类型 | 对应 C++ 类型                   | 占用字节数 |
|--------|-------------|---------------------------------|------------|
| 0      | `dt_boolean`   | `bool`                          | 1          |
| 1      | `dt_utf8char`  | `char` / `u8char`               | 1          |
| 2      | `dt_int8`      | `int8_t`                        | 1          |
| 3      | `dt_int16`     | `int16_t`                       | 2          |
| 4      | `dt_int32`     | `int32_t`                       | 4          |
| 5      | `dt_int64`     | `int64_t`                       | 8          |
| 6      | `dt_uint8`     | `uint8_t`                       | 1          |
| 7      | `dt_uint16`    | `uint16_t`                      | 2          |
| 8      | `dt_uint32`    | `uint32_t`                      | 4          |
| 9      | `dt_uint64`    | `uint64_t`                      | 8          |
| 10     | `dt_float16`   | 半精度浮点 (`_Float16`)          | 2          |
| 11     | `dt_float32`   | `float`                         | 4          |
| 12     | `dt_float64`   | `double`                        | 8          |
| 13     | `dt_bfloat16`  | `bfloat16`                      | 2          |

---

### 获取指定输入张量数据大小

```cpp
size_t get_input_data_size(int idx);
```

**功能：**
返回指定索引输入张量的数据大小。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输入张量索引 |

**返回值：**

* `size_t`：输入张量数据大小

**异常：**

* 索引越界时：`std::out_of_range`

---

### 获取指定输出张量数据大小

```cpp
size_t get_output_data_size(int idx);
```

**功能：**
返回指定索引输出张量的数据大小。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `size_t`：输出张量数据大小

**异常：**

* 索引越界时：`std::out_of_range`

---

### 获取指定输入的字节数

```cpp
size_t get_input_data_bytes(int idx);
```

**功能：**
返回指定索引输入张量的数据字节数。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输入张量索引 |

**返回值：**

* `size_t`：输入张量数据字节数

**异常：**

* 索引越界时：`std::out_of_range`

---

### 获取指定输出的字节数

```cpp
size_t get_output_data_bytes(int idx);
```

**功能：**
返回指定索引输出张量的数据字节数。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `size_t`：输出张量数据字节数

**异常：**

* 索引越界时：`std::out_of_range`

---

### 设置单个输入张量

```cpp
void set_input_tensor(int idx, const nncase::runtime::runtime_tensor& input_tensor);
```

**功能：**
将指定索引的输入张量设置为用户提供的数据。

**参数：**

| 参数名           | 类型               | 描述     |
| ------------- | ---------------- | ------ |
| idx           | `int`            | 输入张量索引 |
| input\_tensor | `runtime_tensor` | 输入张量对象 |

**返回值：**

无

**异常：**

* 索引越界时：`std::out_of_range`
* 设置失败时：`std::runtime_error`

---

### 设置多个输入张量

```cpp
void set_input_tensors(const std::vector<nncase::runtime::runtime_tensor>& input_tensors);
```

**功能：**
批量设置所有输入张量。

**参数：**

| 参数名            | 类型                            | 描述                     |
| -------------- | ----------------------------- | ---------------------- |
| input\_tensors | `std::vector<runtime_tensor>` | 输入张量对象数组，数量必须与模型输入数量一致 |

**返回值：**

无

**异常：**

* 数量不匹配时：`std::invalid_argument`
* 设置失败时：`std::runtime_error`

---

### 获取指定输入张量

```cpp
nncase::runtime::runtime_tensor get_input_tensor(int idx);
```

**功能：**
返回指定索引的输入张量对象。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输入张量索引 |

**返回值：**

* `runtime_tensor`：输入张量对象

**异常：**

* 索引越界：`std::out_of_range`
* 获取失败：`std::runtime_error`

---

### 执行推理

```cpp
void run();
```

**功能：**
执行模型推理，并映射输出张量数据指针。

**返回值：**

无

**异常：**

* 推理失败或输出张量获取失败：`std::runtime_error`

---

### 获取指定输出张量

```cpp
nncase::runtime::runtime_tensor get_output_tensor(int idx);
```

**功能：**
返回指定索引的输出张量对象。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `runtime_tensor`：输出张量对象

**异常：**

* 索引越界：`std::out_of_range`
* 获取失败：`std::runtime_error`

---

### 获取输出张量原始数据指针

```cpp
void* get_output_ptr(int idx);
```

**功能：**
返回指定索引输出张量的原始数据指针，便于直接访问输出数据。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `char*`：输出数据指针

**异常：**

* 索引越界：`std::out_of_range`
* 输出张量未映射：`std::runtime_error`

---

### 获取指定输出张量数据

```cpp
std::vector<char> get_output_data(int idx);
```

**功能：**
返回指定索引输出张量的原始数据，用于获取输出数据。

**参数：**

| 参数名 | 类型    | 描述     |
| --- | ----- | ------ |
| idx | `int` | 输出张量索引 |

**返回值：**

* `std::vector<char>`：输出数据

**异常：**

* 索引越界：`std::out_of_range`
* 输出张量未映射：`std::runtime_error`

---
