Skip to content

Commit 25f9740

Browse files
committed
Merge branch 'Swish' into 'master'
Swish See merge request ai/esp-dl!189
2 parents e3ab57f + a6d6077 commit 25f9740

File tree

4 files changed

+117
-1
lines changed

4 files changed

+117
-1
lines changed

esp-dl/dl/module/include/dl_module_creator.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "dl_module_sqrt.hpp"
4040
#include "dl_module_squeeze.hpp"
4141
#include "dl_module_sub.hpp"
42+
#include "dl_module_swish.hpp"
4243
#include "dl_module_tanh.hpp"
4344
#include "dl_module_transpose.hpp"
4445
#include "dl_module_unsqueeze.hpp"
@@ -152,6 +153,7 @@ class ModuleCreator {
152153
this->register_module("LessOrEqual", LessOrEqual::deserialize);
153154
this->register_module("ReverseSequence", ReverseSequence::deserialize);
154155
this->register_module("Identity", Identity::deserialize);
156+
this->register_module("Swish", Swish::deserialize);
155157
}
156158
}
157159

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
#pragma once

#include "dl_math.hpp"
#include "dl_module_base.hpp"
#include "dl_module_lut.hpp"

namespace dl {
namespace module {
/**
 * @brief Swish (SiLU) activation module: y = x * sigmoid(x).
 *
 * @tparam feature_t supports int16_t and int8_t,
 * - int8_t: stands for operation in int8_t; uses a precomputed LUT when the
 *   model provides one (see deserialize), otherwise computed in float
 * - int16_t: stands for operation in int16_t, computed in float
 * y = x * sigmoid(x)
 */
class Swish : public Module {
public:
    /**
     * @brief Construct a new Swish object.
     *
     * @param name       name of module
     * @param inplace    inplace type.
     * @param quant_type quantization type (symmetric 8/16-bit, or none)
     */
    Swish(const char *name = NULL,
          module_inplace_t inplace = MODULE_NON_INPLACE,
          quant_type_t quant_type = QUANT_TYPE_NONE) :
        Module(name, inplace, quant_type)
    {
    }

    /**
     * @brief Destroy the Swish object.
     */
    ~Swish() {}

    /**
     * @brief Swish is elementwise: the single output has the same shape as the input.
     */
    std::vector<std::vector<int>> get_output_shape(std::vector<std::vector<int>> &input_shapes)
    {
        std::vector<std::vector<int>> output_shapes(1, input_shapes[0]);
        return output_shapes;
    }

    /**
     * @brief Dispatch forward by quantization type; other quant types are a no-op.
     */
    void forward(ModelContext *context, runtime_mode_t mode = RUNTIME_MODE_AUTO)
    {
        if (quant_type == QUANT_TYPE_SYMM_8BIT) {
            forward_template<int8_t>(context, mode);
        } else if (quant_type == QUANT_TYPE_SYMM_16BIT) {
            forward_template<int16_t>(context, mode);
        }
    }

    /**
     * @brief Dequantize each element, compute y = x * sigmoid(x) in float,
     *        then requantize (round + truncate) into the output tensor.
     */
    template <typename T>
    void forward_template(ModelContext *context, runtime_mode_t mode)
    {
        TensorBase *input = context->get_tensor(m_inputs_index[0]);
        TensorBase *output = context->get_tensor(m_outputs_index[0]);
        T *input_ptr = (T *)input->get_element_ptr();
        T *output_ptr = (T *)output->get_element_ptr();

        // input_scale dequantizes; output_scale is the inverse scale used to requantize.
        float input_scale = DL_SCALE(input->exponent);
        float output_scale = DL_RESCALE(output->exponent);
        for (size_t i = 0; i < input->size; i++) {
            float temp = input_ptr[i] * input_scale;
            temp = dl::math::sigmoid(temp) * temp;
            tool::truncate(output_ptr[i], tool::round(temp * output_scale));
        }
    }

    void forward_args(void *args) {}

    /**
     * @brief deserialize Swish module instance by node serialization information
     */
    static Module *deserialize(fbs::FbsModel *fbs_model, std::string node_name)
    {
        Module *op = nullptr;
        quant_type_t quant_type;
        fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);

        // Create module: the 8-bit path prefers a precomputed LUT when the
        // serialized model carries one; everything else falls back to Swish.
        if (quant_type == QUANT_TYPE_SYMM_8BIT) {
            TensorBase *table = fbs_model->get_operation_lut(node_name);
            if (table) {
                op = new LUT(node_name.c_str(), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
            } else {
                op = new Swish(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
            }
        } else {
            op = new Swish(node_name.c_str(), MODULE_INPLACE_CHANGED_BUFFER, quant_type);
        }

        return op;
    }

    void print() { ESP_LOGI("Swish", "quant_type: %s.", quant_type_to_string(quant_type)); }
};
} // namespace module
} // namespace dl

tools/ops_test/config/op_cfg.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2445,7 +2445,13 @@
24452445
time_axis = 1
24462446
export_name_prefix = "ReverseSequence_ishape_1_5_20"
24472447

2448-
2448+
[ops_test.Swish]
test_func = "SWISH_TEST"
quant_bits = ["int8", "int16"]
package = "torch_ops_test"

[[ops_test.Swish.cfg]]
input_shape = [1, 96, 20, 20]
# "ishape" (input shape) — fixed typo "ishap" to match sibling entries,
# e.g. ReverseSequence_ishape_1_5_20.
export_name_prefix = "swish_ishape_1_96_20_20"
24492455

24502456

24512457
[models_test]

tools/ops_test/torch_ops_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,16 @@ def forward(self, input):
858858
return output
859859

860860

861+
class SWISH_TEST(nn.Module):
    """Op-test wrapper exporting a Swish/SiLU activation: y = x * sigmoid(x)."""

    def __init__(self, config):
        super().__init__()
        self.config = config  # test configuration, kept for the export harness
        self.op = nn.SiLU()

    def forward(self, input):
        """Apply SiLU elementwise and return the activated tensor."""
        output = self.op(input)
        return output
861871
if __name__ == "__main__":
862872
print(f"Test {os.path.basename(sys.argv[0])} Module Start...")
863873

0 commit comments

Comments
 (0)