@@ -15,7 +15,7 @@ namespace module {
1515 */
1616class PRelu : public Module {
1717private:
18- TensorBase *alpha ;
18+ TensorBase *m_alpha ;
1919
2020public:
2121 /* *
@@ -29,19 +29,31 @@ class PRelu : public Module {
2929 TensorBase *alpha = NULL ,
3030 module_inplace_t inplace = MODULE_NON_INPLACE,
3131 quant_type_t quant_type = QUANT_TYPE_NONE) :
32- Module (name, inplace, quant_type), alpha (alpha)
32+ Module (name, inplace, quant_type), m_alpha (alpha)
3333 {
3434 }
3535
3636 /* *
3737 * @brief Destroy the PRelu object.
3838 */
39- ~PRelu () { delete this -> alpha ; }
39+ ~PRelu () { delete m_alpha ; }
4040
4141 std::vector<std::vector<int >> get_output_shape (std::vector<std::vector<int >> &input_shapes)
4242 {
4343 assert (input_shapes.size () == 1 );
44- assert (input_shapes[0 ][3 ] == this ->alpha ->shape [0 ]);
44+ if (m_alpha->shape [0 ] != input_shapes[0 ][3 ]) {
45+ TensorBase *new_alpha = new TensorBase (
46+ {input_shapes[0 ][3 ], 1 , 1 }, nullptr , m_alpha->exponent , m_alpha->dtype , true , m_alpha->caps );
47+ if (m_alpha->get_dtype () == DATA_TYPE_INT16) {
48+ int16_t alpha_value = m_alpha->get_element <int16_t >(0 );
49+ int16_t *alpha_ptr = new_alpha->get_element_ptr <int16_t >();
50+ for (int i = 0 ; i < input_shapes[0 ][3 ]; i++) {
51+ alpha_ptr[i] = alpha_value;
52+ }
53+ delete m_alpha;
54+ m_alpha = new_alpha;
55+ }
56+ }
4557 std::vector<std::vector<int >> output_shapes (1 , input_shapes[0 ]);
4658 return output_shapes;
4759 }
@@ -73,7 +85,7 @@ class PRelu : public Module {
7385 TensorBase *input = tensors[m_inputs_index[0 ]];
7486 TensorBase *output = tensors[m_outputs_index[0 ]];
7587
76- std::vector<base::ArgsType<T>> m_args = base::get_activation_args<T>(output, input, PReLU, alpha , mode);
88+ std::vector<base::ArgsType<T>> m_args = base::get_activation_args<T>(output, input, PReLU, m_alpha , mode);
7789 int task_size = m_args.size ();
7890 if (task_size == 1 ) { // single task
7991 forward_args ((void *)&m_args[0 ]);
@@ -93,19 +105,24 @@ class PRelu : public Module {
93105 quant_type_t quant_type;
94106 fbs_model->get_operation_attribute (node_name, " quant_type" , quant_type);
95107 TensorBase *alpha = fbs_model->get_operation_parameter (node_name, 1 );
96- TensorBase *table = fbs_model->get_operation_lut (node_name);
97108 // [c, 1, 1]
98109 assert (alpha->shape .size () == 3 );
99110
100111 // Create module
101- if (table != NULL ) {
102- op = new LUT (node_name.c_str (), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
103- if (alpha != nullptr ) {
104- delete alpha;
112+ if (quant_type == QUANT_TYPE_SYMM_8BIT) {
113+ TensorBase *table = fbs_model->get_operation_lut (node_name);
114+ if (table) {
115+ op = new LUT (node_name.c_str (), table, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
116+ if (alpha != nullptr ) {
117+ delete alpha;
118+ }
119+ } else {
120+ op = new PRelu (node_name.c_str (), alpha, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
105121 }
106- } else if (quant_type == QUANT_TYPE_SYMM_8BIT || quant_type == QUANT_TYPE_SYMM_16BIT) {
122+ } else {
107123 op = new PRelu (node_name.c_str (), alpha, MODULE_INPLACE_CHANGED_BUFFER, quant_type);
108124 }
125+
109126 return op;
110127 }
111128
0 commit comments