|
16 | 16 | using namespace mmdeploy;
|
17 | 17 | using namespace std;
|
18 | 18 |
|
19 |
| -int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, |
20 |
| - mmdeploy_classifier_t* classifier) { |
21 |
| - mmdeploy_context_t context{}; |
22 |
| - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); |
23 |
| - if (ec != MMDEPLOY_SUCCESS) { |
| 19 | +int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_classifier_t* classifier) |
| 20 | +{ |
| 21 | + mmdeploy_context_t context{}; |
| 22 | + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); |
| 23 | + if (ec != MMDEPLOY_SUCCESS) |
| 24 | + { |
| 25 | + return ec; |
| 26 | + } |
| 27 | + ec = mmdeploy_classifier_create_v2(model, context, classifier); |
| 28 | + mmdeploy_context_destroy(context); |
24 | 29 | return ec;
|
25 |
| - } |
26 |
| - ec = mmdeploy_classifier_create_v2(model, context, classifier); |
27 |
| - mmdeploy_context_destroy(context); |
28 |
| - return ec; |
29 | 30 | }
|
30 | 31 |
|
31 |
| -int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, |
32 |
| - int device_id, mmdeploy_classifier_t* classifier) { |
33 |
| - mmdeploy_model_t model{}; |
| 32 | +int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_classifier_t* classifier) |
| 33 | +{ |
| 34 | + mmdeploy_model_t model{}; |
34 | 35 |
|
35 |
| - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { |
| 36 | + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) |
| 37 | + { |
| 38 | + return ec; |
| 39 | + } |
| 40 | + auto ec = mmdeploy_classifier_create(model, device_name, device_id, classifier); |
| 41 | + mmdeploy_model_destroy(model); |
36 | 42 | return ec;
|
37 |
| - } |
38 |
| - auto ec = mmdeploy_classifier_create(model, device_name, device_id, classifier); |
39 |
| - mmdeploy_model_destroy(model); |
40 |
| - return ec; |
41 | 43 | }
|
42 | 44 |
|
43 |
| -int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, |
44 |
| - mmdeploy_classifier_t* classifier) { |
45 |
| - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier); |
| 45 | +int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_classifier_t* classifier) |
| 46 | +{ |
| 47 | + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier); |
46 | 48 | }
|
47 | 49 |
|
48 |
| -int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, |
49 |
| - mmdeploy_value_t* value) { |
50 |
| - return mmdeploy_common_create_input(mats, mat_count, value); |
| 50 | +int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value) |
| 51 | +{ |
| 52 | + return mmdeploy_common_create_input(mats, mat_count, value); |
51 | 53 | }
|
52 | 54 |
|
53 |
| -int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, const mmdeploy_mat_t* mats, |
54 |
| - int mat_count, mmdeploy_classification_t** results, |
55 |
| - int** result_count) { |
56 |
| - wrapped<mmdeploy_value_t> input; |
57 |
| - if (auto ec = mmdeploy_classifier_create_input(mats, mat_count, input.ptr())) { |
58 |
| - return ec; |
59 |
| - } |
60 |
| - wrapped<mmdeploy_value_t> output; |
61 |
| - if (auto ec = mmdeploy_classifier_apply_v2(classifier, input, output.ptr())) { |
62 |
| - return ec; |
63 |
| - } |
64 |
| - if (auto ec = mmdeploy_classifier_get_result(output, results, result_count)) { |
65 |
| - return ec; |
66 |
| - } |
67 |
| - return MMDEPLOY_SUCCESS; |
| 55 | +int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_classification_t** results, int** result_count) |
| 56 | +{ |
| 57 | + wrapped<mmdeploy_value_t> input; |
| 58 | + if (auto ec = mmdeploy_classifier_create_input(mats, mat_count, input.ptr())) |
| 59 | + { |
| 60 | + return ec; |
| 61 | + } |
| 62 | + wrapped<mmdeploy_value_t> output; |
| 63 | + if (auto ec = mmdeploy_classifier_apply_v2(classifier, input, output.ptr())) |
| 64 | + { |
| 65 | + return ec; |
| 66 | + } |
| 67 | + if (auto ec = mmdeploy_classifier_get_result(output, results, result_count)) |
| 68 | + { |
| 69 | + return ec; |
| 70 | + } |
| 71 | + return MMDEPLOY_SUCCESS; |
68 | 72 | }
|
69 | 73 |
|
70 |
| -int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, mmdeploy_value_t input, |
71 |
| - mmdeploy_value_t* output) { |
72 |
| - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)classifier, input, output); |
| 74 | +int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, mmdeploy_value_t input, mmdeploy_value_t* output) |
| 75 | +{ |
| 76 | + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)classifier, input, output); |
73 | 77 | }
|
74 | 78 |
|
75 |
| -int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, mmdeploy_sender_t input, |
76 |
| - mmdeploy_sender_t* output) { |
77 |
| - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)classifier, input, output); |
| 79 | +int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, mmdeploy_sender_t input, mmdeploy_sender_t* output) |
| 80 | +{ |
| 81 | + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)classifier, input, output); |
78 | 82 | }
|
79 | 83 |
|
80 |
| -int mmdeploy_classifier_get_result(mmdeploy_value_t output, mmdeploy_classification_t** results, |
81 |
| - int** result_count) { |
82 |
| - if (!output || !results || !result_count) { |
83 |
| - return MMDEPLOY_E_INVALID_ARG; |
84 |
| - } |
85 |
| - try { |
86 |
| - Value& value = Cast(output)->front(); |
87 |
| - |
88 |
| - auto classify_outputs = from_value<vector<mmcls::Labels>>(value); |
89 |
| - |
90 |
| - vector<int> _result_count; |
91 |
| - _result_count.reserve(classify_outputs.size()); |
92 |
| - |
93 |
| - for (const auto& cls_output : classify_outputs) { |
94 |
| - _result_count.push_back((int)cls_output.size()); |
| 84 | +int mmdeploy_classifier_get_result(mmdeploy_value_t output, mmdeploy_classification_t** results, int** result_count) |
| 85 | +{ |
| 86 | + if (!output || !results || !result_count) |
| 87 | + { |
| 88 | + return MMDEPLOY_E_INVALID_ARG; |
95 | 89 | }
|
96 |
| - |
97 |
| - auto total = std::accumulate(begin(_result_count), end(_result_count), 0); |
98 |
| - |
99 |
| - std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); |
100 |
| - std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); |
101 |
| - |
102 |
| - std::unique_ptr<mmdeploy_classification_t[]> result_data( |
103 |
| - new mmdeploy_classification_t[total]{}); |
104 |
| - auto result_ptr = result_data.get(); |
105 |
| - for (const auto& cls_output : classify_outputs) { |
106 |
| - for (const auto& label : cls_output) { |
107 |
| - result_ptr->label_id = label.label_id; |
108 |
| - result_ptr->score = label.score; |
109 |
| - ++result_ptr; |
110 |
| - } |
| 90 | + try |
| 91 | + { |
| 92 | + Value& value = Cast(output)->front(); |
| 93 | + |
| 94 | + auto classify_outputs = from_value<vector<mmcls::Labels>>(value); |
| 95 | + |
| 96 | + vector<int> _result_count; |
| 97 | + _result_count.reserve(classify_outputs.size()); |
| 98 | + |
| 99 | + for (const auto& cls_output : classify_outputs) |
| 100 | + { |
| 101 | + _result_count.push_back((int)cls_output.size()); |
| 102 | + } |
| 103 | + |
| 104 | + auto total = std::accumulate(begin(_result_count), end(_result_count), 0); |
| 105 | + |
| 106 | + std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); |
| 107 | + std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); |
| 108 | + |
| 109 | + std::unique_ptr<mmdeploy_classification_t[]> result_data( |
| 110 | + new mmdeploy_classification_t[total]{}); |
| 111 | + auto result_ptr = result_data.get(); |
| 112 | + for (const auto& cls_output : classify_outputs) |
| 113 | + { |
| 114 | + for (const auto& label : cls_output) |
| 115 | + { |
| 116 | + result_ptr->label_id = label.label_id; |
| 117 | + result_ptr->score = label.score; |
| 118 | + ++result_ptr; |
| 119 | + } |
| 120 | + } |
| 121 | + |
| 122 | + *result_count = result_count_data.release(); |
| 123 | + *results = result_data.release(); |
| 124 | + |
| 125 | + return MMDEPLOY_SUCCESS; |
111 | 126 | }
|
112 |
| - |
113 |
| - *result_count = result_count_data.release(); |
114 |
| - *results = result_data.release(); |
115 |
| - |
116 |
| - return MMDEPLOY_SUCCESS; |
117 |
| - } catch (const std::exception& e) { |
118 |
| - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); |
119 |
| - } catch (...) { |
120 |
| - MMDEPLOY_ERROR("unknown exception caught"); |
121 |
| - } |
122 |
| - return MMDEPLOY_E_FAIL; |
| 127 | + catch (const std::exception& e) |
| 128 | + { |
| 129 | + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); |
| 130 | + } |
| 131 | + catch (...) |
| 132 | + { |
| 133 | + MMDEPLOY_ERROR("unknown exception caught"); |
| 134 | + } |
| 135 | + return MMDEPLOY_E_FAIL; |
123 | 136 | }
|
124 | 137 |
|
125 |
| -void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, const int* result_count, |
126 |
| - int count) { |
127 |
| - delete[] results; |
128 |
| - delete[] result_count; |
| 138 | +void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, const int* result_count, int count) |
| 139 | +{ |
| 140 | + delete[] results; |
| 141 | + delete[] result_count; |
129 | 142 | }
|
130 | 143 |
|
131 |
| -void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier) { |
132 |
| - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)classifier); |
| 144 | +void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier) |
| 145 | +{ |
| 146 | + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)classifier); |
133 | 147 | }
|
0 commit comments