Skip to content

Commit 3292cbd

Browse files
committed
Merge branch 'test/int16' into 'master'
test: generate int16 test cases. See merge request ai/esp-dl!112
2 parents a383981 + 470ab4e commit 3292cbd

File tree

17 files changed

+120
-54
lines changed

17 files changed

+120
-54
lines changed

.gitlab/ci/gen_test_cases.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ gen_espdl_ops_cases:
2828
- IMAGE: [python:3.11]
2929
TORCH: ["torch==2.5.0"]
3030
TARGET: [esp32p4]
31-
BITS: [8]
31+
BITS: [8, 16]
3232
- IMAGE: [python:3.10]
3333
TORCH: [torch]
3434
TARGET: [esp32s3]
35-
BITS: [8]
35+
BITS: [8, 16]
3636
variables:
3737
MODEL_PATH: test_apps/esp-dl/models
3838
CONFIG_FILE: tools/ops_test/config/op_cfg.toml

esp-dl/dl/model/include/dl_model_base.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class Model {
5656
* The address of model data while location is MODEL_LOCATION_IN_FLASH_RODATA.
5757
* The label of partition while location is MODEL_LOCATION_IN_FLASH_PARTITION.
5858
* The path of model while location is MODEL_LOCATION_IN_SDCARD.
59-
* @param location The model location.
6059
* @param model_index The model index of packed models.
60+
* @param location The model location.
6161
* @param internal_size Internal ram size, in bytes
6262
* @param mm_type Type of memory manager
6363
* @param key The key of encrypted model.
@@ -76,8 +76,8 @@ class Model {
7676
* The address of model data while location is MODEL_LOCATION_IN_FLASH_RODATA.
7777
* The label of partition while location is MODEL_LOCATION_IN_FLASH_PARTITION.
7878
* The path of model while location is MODEL_LOCATION_IN_SDCARD.
79-
* @param location The model location.
8079
* @param model_name The model name of packed models.
80+
* @param location The model location.
8181
* @param internal_size Internal ram size, in bytes
8282
* @param mm_type Type of memory manager
8383
* @param key The key of encrypted model.

esp-dl/dl/module/include/dl_module_clip.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,15 @@ class Clip : public Module {
103103
Module *op = nullptr;
104104
quant_type_t quant_type;
105105
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
106-
TensorBase *table = fbs_model->get_operation_lut(node_name);
107106

108107
// Create module
109-
if (table != NULL) {
110-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
111-
} else {
108+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
109+
TensorBase *table = fbs_model->get_operation_lut(node_name);
110+
if (table) {
111+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
112+
}
113+
}
114+
if (op == nullptr) {
112115
TensorBase *min = fbs_model->get_operation_parameter(node_name, 1);
113116
TensorBase *max = fbs_model->get_operation_parameter(node_name, 2);
114117
assert(min->exponent == max->exponent);

esp-dl/dl/module/include/dl_module_exp.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,18 @@ class Exp : public Module {
8585
Module *op = nullptr;
8686
quant_type_t quant_type;
8787
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
88-
TensorBase *table = fbs_model->get_operation_lut(node_name);
8988

9089
// Create module
91-
if (table != NULL) {
92-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
90+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
91+
TensorBase *table = fbs_model->get_operation_lut(node_name);
92+
if (table) {
93+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
94+
} else {
95+
op = new Exp(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
96+
}
9397
} else {
9498
op = new Exp(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
9599
}
96-
op->print();
97100

98101
return op;
99102
}

esp-dl/dl/module/include/dl_module_hard_sigmoid.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,15 @@ class HardSigmoid : public Module {
9090
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
9191
fbs_model->get_operation_attribute(node_name, "alpha", alpha);
9292
fbs_model->get_operation_attribute(node_name, "beta", beta);
93-
TensorBase *table = fbs_model->get_operation_lut(node_name);
9493

9594
// Create module
96-
if (table != NULL) {
97-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
95+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
96+
TensorBase *table = fbs_model->get_operation_lut(node_name);
97+
if (table) {
98+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
99+
} else {
100+
op = new HardSigmoid(node_name.c_str(), alpha, beta, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
101+
}
98102
} else {
99103
op = new HardSigmoid(node_name.c_str(), alpha, beta, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
100104
}

esp-dl/dl/module/include/dl_module_hard_swish.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,18 @@ class HardSwish : public Module {
7878
Module *op = nullptr;
7979
quant_type_t quant_type;
8080
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
81-
TensorBase *table = fbs_model->get_operation_lut(node_name);
8281

8382
// Create module
84-
if (table != NULL) {
85-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
83+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
84+
TensorBase *table = fbs_model->get_operation_lut(node_name);
85+
if (table) {
86+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
87+
} else {
88+
op = new HardSwish(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
89+
}
8690
} else {
8791
op = new HardSwish(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
8892
}
89-
op->print();
9093

9194
return op;
9295
}

esp-dl/dl/module/include/dl_module_leaky_relu.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ class LeakyRelu : public Module {
8888
float alpha = 0.01;
8989
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
9090
fbs_model->get_operation_attribute(node_name, "alpha", alpha);
91-
TensorBase *table = fbs_model->get_operation_lut(node_name);
9291

9392
// Create module
94-
if (table != NULL) {
95-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
93+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
94+
TensorBase *table = fbs_model->get_operation_lut(node_name);
95+
if (table) {
96+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
97+
} else {
98+
op = new LeakyRelu(node_name.c_str(), alpha, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
99+
}
96100
} else {
97101
op = new LeakyRelu(node_name.c_str(), alpha, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
98102
}

esp-dl/dl/module/include/dl_module_log.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,15 @@ class Log : public Module {
8484
Module *op = nullptr;
8585
quant_type_t quant_type;
8686
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
87-
TensorBase *table = fbs_model->get_operation_lut(node_name);
8887

8988
// Create module
90-
if (table != NULL) {
91-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
89+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
90+
TensorBase *table = fbs_model->get_operation_lut(node_name);
91+
if (table) {
92+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
93+
} else {
94+
op = new Log(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
95+
}
9296
} else {
9397
op = new Log(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
9498
}

esp-dl/dl/module/include/dl_module_prelu.hpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace module {
1515
*/
1616
class PRelu : public Module {
1717
private:
18-
TensorBase *alpha;
18+
TensorBase *m_alpha;
1919

2020
public:
2121
/**
@@ -29,19 +29,31 @@ class PRelu : public Module {
2929
TensorBase *alpha = NULL,
3030
module_inplace_t inplace = MODULE_NON_INPLACE,
3131
quant_type_t quant_type = QUANT_TYPE_NONE) :
32-
Module(name, inplace, quant_type), alpha(alpha)
32+
Module(name, inplace, quant_type), m_alpha(alpha)
3333
{
3434
}
3535

3636
/**
3737
* @brief Destroy the PRelu object.
3838
*/
39-
~PRelu() { delete this->alpha; }
39+
~PRelu() { delete m_alpha; }
4040

4141
std::vector<std::vector<int>> get_output_shape(std::vector<std::vector<int>> &input_shapes)
4242
{
4343
assert(input_shapes.size() == 1);
44-
assert(input_shapes[0][3] == this->alpha->shape[0]);
44+
if (m_alpha->shape[0] != input_shapes[0][3]) {
45+
TensorBase *new_alpha = new TensorBase(
46+
{input_shapes[0][3], 1, 1}, nullptr, m_alpha->exponent, m_alpha->dtype, true, m_alpha->caps);
47+
if (m_alpha->get_dtype() == DATA_TYPE_INT16) {
48+
int16_t alpha_value = m_alpha->get_element<int16_t>(0);
49+
int16_t *alpha_ptr = new_alpha->get_element_ptr<int16_t>();
50+
for (int i = 0; i < input_shapes[0][3]; i++) {
51+
alpha_ptr[i] = alpha_value;
52+
}
53+
delete m_alpha;
54+
m_alpha = new_alpha;
55+
}
56+
}
4557
std::vector<std::vector<int>> output_shapes(1, input_shapes[0]);
4658
return output_shapes;
4759
}
@@ -73,7 +85,7 @@ class PRelu : public Module {
7385
TensorBase *input = tensors[m_inputs_index[0]];
7486
TensorBase *output = tensors[m_outputs_index[0]];
7587

76-
std::vector<base::ArgsType<T>> m_args = base::get_activation_args<T>(output, input, PReLU, alpha, mode);
88+
std::vector<base::ArgsType<T>> m_args = base::get_activation_args<T>(output, input, PReLU, m_alpha, mode);
7789
int task_size = m_args.size();
7890
if (task_size == 1) { // single task
7991
forward_args((void *)&m_args[0]);
@@ -93,19 +105,24 @@ class PRelu : public Module {
93105
quant_type_t quant_type;
94106
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
95107
TensorBase *alpha = fbs_model->get_operation_parameter(node_name, 1);
96-
TensorBase *table = fbs_model->get_operation_lut(node_name);
97108
// [c, 1, 1]
98109
assert(alpha->shape.size() == 3);
99110

100111
// Create module
101-
if (table != NULL) {
102-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
103-
if (alpha != nullptr) {
104-
delete alpha;
112+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
113+
TensorBase *table = fbs_model->get_operation_lut(node_name);
114+
if (table) {
115+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
116+
if (alpha != nullptr) {
117+
delete alpha;
118+
}
119+
} else {
120+
op = new PRelu(node_name.c_str(), alpha, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
105121
}
106-
} else if (quant_type == QUANT_TYPE_SYMM_8BIT || quant_type == QUANT_TYPE_SYMM_16BIT) {
122+
} else {
107123
op = new PRelu(node_name.c_str(), alpha, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
108124
}
125+
109126
return op;
110127
}
111128

esp-dl/dl/module/include/dl_module_relu.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,15 @@ class Relu : public Module {
104104
Module *op = nullptr;
105105
quant_type_t quant_type;
106106
fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
107-
TensorBase *table = fbs_model->get_operation_lut(node_name);
108107

109108
// Create module
110-
if (table != NULL) {
111-
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
109+
if (quant_type == QUANT_TYPE_SYMM_8BIT) {
110+
TensorBase *table = fbs_model->get_operation_lut(node_name);
111+
if (table) {
112+
op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
113+
} else {
114+
op = new Relu(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
115+
}
112116
} else {
113117
op = new Relu(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
114118
}

0 commit comments

Comments (0)