|
| 1 | +// Copyright (c) OpenMMLab. All rights reserved. |
| 2 | + |
| 3 | +#include <cctype> |
| 4 | +#include <opencv2/imgproc.hpp> |
| 5 | +#include <iostream> |
| 6 | +#include <fstream> |
| 7 | + |
| 8 | +#include "mmdeploy/core/device.h" |
| 9 | +#include "mmdeploy/core/registry.h" |
| 10 | +#include "mmdeploy/core/serialization.h" |
| 11 | +#include "mmdeploy/core/tensor.h" |
| 12 | +#include "mmdeploy/core/utils/device_utils.h" |
| 13 | +#include "mmdeploy/core/utils/formatter.h" |
| 14 | +#include "mmdeploy/core/value.h" |
| 15 | +#include "mmdeploy/experimental/module_adapter.h" |
| 16 | +#include "mmpose.h" |
| 17 | +#include "opencv_utils.h" |
| 18 | + |
| 19 | +namespace mmdeploy::mmpose { |
| 20 | + |
| 21 | +using std::string; |
| 22 | +using std::vector; |
| 23 | + |
| 24 | +class YOLOXPose : public MMPose { |
| 25 | + public: |
| 26 | + explicit YOLOXPose(const Value& config) : MMPose(config) { |
| 27 | + if (config.contains("params")) { |
| 28 | + auto& params = config["params"]; |
| 29 | + if (params.contains("score_thr")) { |
| 30 | + from_value(params["score_thr"], score_thr_); |
| 31 | + } |
| 32 | + } |
| 33 | + } |
| 34 | + |
| 35 | + Result<Value> operator()(const Value& _data, const Value& _prob) { |
| 36 | + MMDEPLOY_DEBUG("preprocess_result: {}", _data); |
| 37 | + MMDEPLOY_DEBUG("inference_result: {}", _prob); |
| 38 | + |
| 39 | + Device cpu_device{"cpu"}; |
| 40 | + OUTCOME_TRY(auto dets, |
| 41 | + MakeAvailableOnDevice(_prob["dets"].get<Tensor>(), cpu_device, stream())); |
| 42 | + OUTCOME_TRY(auto keypoints, |
| 43 | + MakeAvailableOnDevice(_prob["keypoints"].get<Tensor>(), cpu_device, stream())); |
| 44 | + OUTCOME_TRY(stream().Wait()); |
| 45 | + if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) { |
| 46 | + MMDEPLOY_ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(), |
| 47 | + (int)dets.data_type()); |
| 48 | + return Status(eNotSupported); |
| 49 | + } |
| 50 | + if (!(keypoints.shape().size() == 4 && keypoints.data_type() == DataType::kFLOAT)) { |
| 51 | + MMDEPLOY_ERROR("unsupported `keypoints` tensor, shape: {}, dtype: {}", keypoints.shape(), |
| 52 | + (int)keypoints.data_type()); |
| 53 | + return Status(eNotSupported); |
| 54 | + } |
| 55 | + auto& img_metas = _data["img_metas"]; |
| 56 | + vector<float> scale_factor; |
| 57 | + if (img_metas.contains("scale_factor")) { |
| 58 | + from_value(img_metas["scale_factor"], scale_factor); |
| 59 | + } else { |
| 60 | + scale_factor = {1.f, 1.f, 1.f, 1.f}; |
| 61 | + } |
| 62 | + PoseDetectorOutput output; |
| 63 | + |
| 64 | + float* keypoints_data = keypoints.data<float>(); |
| 65 | + float* dets_data = dets.data<float>(); |
| 66 | + int num_dets = dets.shape(1), num_pts = keypoints.shape(2); |
| 67 | + float s = 0, x1=0, y1=0, x2=0, y2=0; |
| 68 | + |
| 69 | + // fprintf(stdout, "num_dets= %d num_pts = %d\n", num_dets, num_pts); |
| 70 | + for (int i = 0; i < dets.shape(0) * num_dets; i++){ |
| 71 | + x1 = (*(dets_data++)) / scale_factor[0]; |
| 72 | + y1 = (*(dets_data++)) / scale_factor[1]; |
| 73 | + x2 = (*(dets_data++)) / scale_factor[2]; |
| 74 | + y2 = (*(dets_data++)) / scale_factor[3]; |
| 75 | + s = *(dets_data++); |
| 76 | + // fprintf(stdout, "box %.2f %.2f %.2f %.2f %.6f\n", i, x1,y1,x2,y2,s); |
| 77 | + |
| 78 | + if (s <= score_thr_) { |
| 79 | + keypoints_data += num_pts * 3; |
| 80 | + continue; |
| 81 | + } |
| 82 | + output.detections.push_back({{x1, y1, x2, y2}, s}); |
| 83 | + for (int k = 0; k < num_pts; k++) { |
| 84 | + x1 = (*(keypoints_data++)) / scale_factor[0]; |
| 85 | + y1 = (*(keypoints_data++)) / scale_factor[1]; |
| 86 | + s = *(keypoints_data++); |
| 87 | + // fprintf(stdout, "point %d, index %d, %.2f %.2f %.6f\n", k, x1, y1, s); |
| 88 | + output.key_points.push_back({{x1, y1}, s}); |
| 89 | + } |
| 90 | + } |
| 91 | + return to_value(output); |
| 92 | + } |
| 93 | + |
| 94 | + protected: |
| 95 | + float score_thr_ = 0.001; |
| 96 | + |
| 97 | +}; |
| 98 | + |
| 99 | +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMPose, YOLOXPose); |
| 100 | + |
| 101 | +} // namespace mmdeploy::mmpose |
0 commit comments