diff --git a/csrc/mmdeploy/CMakeLists.txt b/csrc/mmdeploy/CMakeLists.txt index 6bfbd3a95a..4d05e7ee79 100644 --- a/csrc/mmdeploy/CMakeLists.txt +++ b/csrc/mmdeploy/CMakeLists.txt @@ -18,4 +18,8 @@ if (MMDEPLOY_BUILD_SDK) add_subdirectory(net) add_subdirectory(codebase) add_subdirectory(apis) + + if (TRITON_MMDEPLOY_BACKEND) + add_subdirectory(triton) + endif () endif () diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp index e20ec6a224..806a420fd5 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp @@ -11,27 +11,28 @@ namespace mmdeploy { namespace cxx { -class Pipeline : public NonMovable { +class Pipeline : public UniqueHandle { public: Pipeline(const Value& config, const Context& context) { - mmdeploy_pipeline_t pipeline{}; - auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &pipeline); + auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &handle_); if (ec != MMDEPLOY_SUCCESS) { throw_exception(static_cast(ec)); } - pipeline_ = pipeline; } ~Pipeline() { - if (pipeline_) { - mmdeploy_pipeline_destroy(pipeline_); - pipeline_ = nullptr; + if (handle_) { + mmdeploy_pipeline_destroy(handle_); + handle_ = nullptr; } } + Pipeline(Pipeline&&) noexcept = default; + Pipeline& operator=(Pipeline&&) noexcept = default; + Value Apply(const Value& inputs) { mmdeploy_value_t tmp{}; - auto ec = mmdeploy_pipeline_apply(pipeline_, (mmdeploy_value_t)&inputs, &tmp); + auto ec = mmdeploy_pipeline_apply(handle_, (mmdeploy_value_t)&inputs, &tmp); if (ec != MMDEPLOY_SUCCESS) { throw_exception(static_cast(ec)); } @@ -50,7 +51,7 @@ class Pipeline : public NonMovable { if (ec != MMDEPLOY_SUCCESS) { throw_exception(static_cast(ec)); } - auto outputs = Apply(*reinterpret_cast(inputs)); + auto outputs = this->Apply(*reinterpret_cast(inputs)); mmdeploy_value_destroy(inputs); return outputs; @@ -65,9 +66,6 @@ class Pipeline : public NonMovable { } return rets; } - - private: - mmdeploy_pipeline_t pipeline_{}; }; } // namespace cxx diff --git a/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp b/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp index 27dc6578b5..59df608b59 100644 --- a/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp +++ b/csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp @@ -64,9 +64,6 @@ static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); } Result RTMDetSepBNHead::GetBBoxes(const Value& prep_res, const std::vector& bbox_preds, const std::vector& cls_scores) const { - MMDEPLOY_DEBUG("bbox_pred: {}, {}", bbox_preds[0].shape(), dets[0].data_type()); - MMDEPLOY_DEBUG("cls_score: {}, {}", scores[0].shape(), scores[0].data_type()); - std::vector filter_boxes; std::vector obj_probs; std::vector class_ids; diff --git a/csrc/mmdeploy/codebase/mmseg/segment.cpp b/csrc/mmdeploy/codebase/mmseg/segment.cpp index 56811a4fad..6f96f99f27 100644 --- a/csrc/mmdeploy/codebase/mmseg/segment.cpp +++ b/csrc/mmdeploy/codebase/mmseg/segment.cpp @@ -90,6 +90,7 @@ class ResizeMask : public MMSegmentation { std::vector axes = {0, 3, 1, 2}; ::mmdeploy::operation::Context ctx(host, stream_); OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes)); + tensor_score.Squeeze(0); } SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_}; diff --git a/csrc/mmdeploy/core/logger.h b/csrc/mmdeploy/core/logger.h index 73de4f0ee1..5d8947edf0 100644 --- a/csrc/mmdeploy/core/logger.h +++ b/csrc/mmdeploy/core/logger.h @@ -47,7 +47,7 @@ MMDEPLOY_API void SetLogger(spdlog::logger *logger); #endif #ifdef SPDLOG_LOGGER_CALL -#define MMDEPLOY_LOG(level, ...) SPDLOG_LOGGER_CALL(mmdeploy::GetLogger(), level, __VA_ARGS__) +#define MMDEPLOY_LOG(level, ...) SPDLOG_LOGGER_CALL(::mmdeploy::GetLogger(), level, __VA_ARGS__) #else #define MMDEPLOY_LOG(level, ...) mmdeploy::GetLogger()->log(level, __VA_ARGS__) #endif diff --git a/csrc/mmdeploy/core/model.h b/csrc/mmdeploy/core/model.h index fcb396d267..73a575f3d3 100644 --- a/csrc/mmdeploy/core/model.h +++ b/csrc/mmdeploy/core/model.h @@ -30,8 +30,9 @@ struct model_meta_info_t { struct deploy_meta_info_t { std::string version; + std::string task; std::vector models; - MMDEPLOY_ARCHIVE_MEMBERS(version, models); + MMDEPLOY_ARCHIVE_MEMBERS(version, task, models); }; class ModelImpl; diff --git a/csrc/mmdeploy/device/cuda/cuda_device.h b/csrc/mmdeploy/device/cuda/cuda_device.h index 20b894652d..55d9b439f0 100644 --- a/csrc/mmdeploy/device/cuda/cuda_device.h +++ b/csrc/mmdeploy/device/cuda/cuda_device.h @@ -1,5 +1,8 @@ // Copyright (c) OpenMMLab. All rights reserved. +#ifndef MMDEPLOY_SRC_DEVICE_CUDA_CUDE_DEVICE_H_ +#define MMDEPLOY_SRC_DEVICE_CUDA_CUDE_DEVICE_H_ + #include #include @@ -196,3 +199,5 @@ class CudaDeviceGuard { }; } // namespace mmdeploy::framework + +#endif // MMDEPLOY_SRC_DEVICE_CUDA_CUDE_DEVICE_H_ diff --git a/csrc/mmdeploy/graph/inference.cpp b/csrc/mmdeploy/graph/inference.cpp index 8f5c8d1699..de9f632c60 100644 --- a/csrc/mmdeploy/graph/inference.cpp +++ b/csrc/mmdeploy/graph/inference.cpp @@ -14,24 +14,30 @@ using namespace framework; InferenceBuilder::InferenceBuilder(Value config) : Builder(std::move(config)) {} Result> InferenceBuilder::BuildImpl() { - auto& model_config = config_["params"]["model"]; - Model model; - if (model_config.is_any()) { - model = model_config.get(); - } else { - auto model_name = model_config.get(); - if (auto m = Maybe{config_} / "context" / "model" / model_name / identity{}) { - model = *m; + Value pipeline_config; + auto context = config_.value("context", Value(ValueType::kObject)); + const auto& params = config_["params"]; + if (params.contains("model")) { + auto& model_config = params["model"]; + Model model; + if (model_config.is_any()) { + model = model_config.get(); } else { - model = Model(model_name); + auto model_name = model_config.get(); + if (auto m = Maybe{config_} / "context" / "model" / model_name / identity{}) { + model = *m; + } else { + model = Model(model_name); + } } + OUTCOME_TRY(pipeline_config, model.ReadConfig("pipeline.json")); + context["model"] = std::move(model); + } else if (params.contains("pipeline")) { + assert(context.contains("model")); + auto model = context["model"].get(); + OUTCOME_TRY(pipeline_config, model.ReadConfig(params["pipeline"].get())); } - OUTCOME_TRY(auto pipeline_config, model.ReadConfig("pipeline.json")); - - auto context = config_.value("context", Value(ValueType::kObject)); - context["model"] = std::move(model); - if (context.contains("scope")) { auto name = config_.value("name", config_["type"].get()); auto scope = context["scope"].get_ref()->CreateScope(name); diff --git a/csrc/mmdeploy/graph/task.cpp b/csrc/mmdeploy/graph/task.cpp index 6cb6c4a798..6f876348d6 100644 --- a/csrc/mmdeploy/graph/task.cpp +++ b/csrc/mmdeploy/graph/task.cpp @@ -96,6 +96,7 @@ Result> TaskBuilder::BuildImpl() { task->is_thread_safe_ = config_.value("is_thread_safe", false); return std::move(task); } catch (const std::exception& e) { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); MMDEPLOY_ERROR("error parsing config: {}", config_); return nullptr; } diff --git a/csrc/mmdeploy/preprocess/transform_module.cpp b/csrc/mmdeploy/preprocess/transform_module.cpp index b718843ea8..269c170edf 100644 --- a/csrc/mmdeploy/preprocess/transform_module.cpp +++ b/csrc/mmdeploy/preprocess/transform_module.cpp @@ -10,46 +10,112 @@ namespace mmdeploy { class TransformModule { public: - ~TransformModule(); - TransformModule(TransformModule&&) noexcept; + ~TransformModule() = default; + TransformModule(TransformModule&&) noexcept = default; - explicit TransformModule(const Value& args); - Result operator()(const Value& input); + explicit TransformModule(const Value& args) { + const auto type = "Compose"; + auto creator = gRegistry().Get(type); + if (!creator) { + MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type, + gRegistry().List()); + throw_exception(eEntryNotFound); + } + auto cfg = args; + if (cfg.contains("device")) { + MMDEPLOY_WARN("force using device: {}", cfg["device"].get()); + auto device = Device(cfg["device"].get()); + cfg["context"]["device"] = device; + cfg["context"]["stream"] = Stream::GetDefault(device); + } + transform_ = creator->Create(cfg); + } + + Result operator()(const Value& input) { + auto data = input; + OUTCOME_TRY(transform_->Apply(data)); + return data; + } private: std::unique_ptr transform_; }; -TransformModule::~TransformModule() = default; +MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Transform, 0), [](const Value& config) { + return CreateTask(TransformModule{config}); +}); -TransformModule::TransformModule(TransformModule&&) noexcept = default; +#if 0 +class Preload { + public: + explicit Preload(const Value& args) { + const auto type = "Compose"; + auto creator = gRegistry().Get(type); + if (!creator) { + MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type, + gRegistry().List()); + throw_exception(eEntryNotFound); + } + auto cfg = args; + if (cfg.contains("device")) { + MMDEPLOY_WARN("force using device: {}", cfg["device"].get()); + auto device = Device(cfg["device"].get()); + cfg["context"]["device"] = device; + cfg["context"]["stream"] = Stream::GetDefault(device); + } + const auto& ctx = cfg["context"]; + ctx["device"].get_to(device_); + ctx["stream"].get_to(stream_); + } -TransformModule::TransformModule(const Value& args) { - const auto type = "Compose"; - auto creator = gRegistry().Get(type); - if (!creator) { - MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type, - gRegistry().List()); - throw_exception(eEntryNotFound); + Result operator()(const Value& input) { + auto data = input; + if (device_.is_device()) { + bool need_sync = false; + OUTCOME_TRY(Process(data, need_sync)); + MMDEPLOY_ERROR("need_sync = {}", need_sync); + MMDEPLOY_ERROR("{}", data); + if (need_sync) { + OUTCOME_TRY(stream_.Wait()); + } + } + return data; } - auto cfg = args; - if (cfg.contains("device")) { - MMDEPLOY_WARN("force using device: {}", cfg["device"].get()); - auto device = Device(cfg["device"].get()); - cfg["context"]["device"] = device; - cfg["context"]["stream"] = Stream::GetDefault(device); + + Result Process(Value& item, bool& need_sync) { + if (item.is_any()) { + auto& mat = item.get_ref(); + if (mat.device().is_host()) { + Mat tmp(mat.height(), mat.width(), mat.pixel_format(), mat.type(), device_); + OUTCOME_TRY(stream_.Copy(mat.buffer(), tmp.buffer(), mat.byte_size())); + mat = tmp; + need_sync |= true; + } + } else if (item.is_any()) { + auto& ten = item.get_ref(); + if (ten.device().is_host()) { + TensorDesc desc = ten.desc(); + desc.device = device_; + Tensor tmp(desc); + OUTCOME_TRY(stream_.Copy(ten.buffer(), tmp.buffer(), ten.byte_size())); + ten = tmp; + need_sync |= true; + } + } else if (item.is_array() || item.is_object()) { + for (auto& child : item) { + OUTCOME_TRY(Process(child, need_sync)); + } + } + return success(); } - transform_ = creator->Create(cfg); -} -Result TransformModule::operator()(const Value& input) { - auto data = input; - OUTCOME_TRY(transform_->Apply(data)); - return data; -} + private: + Device device_; + Stream stream_; +}; -MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Transform, 0), [](const Value& config) { - return CreateTask(TransformModule{config}); -}); +MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Preload, 0), + [](const Value& config) { return CreateTask(Preload{config}); }); +#endif } // namespace mmdeploy diff --git a/csrc/mmdeploy/triton/CMakeLists.txt b/csrc/mmdeploy/triton/CMakeLists.txt new file mode 100644 index 0000000000..c77b8610d9 --- /dev/null +++ b/csrc/mmdeploy/triton/CMakeLists.txt @@ -0,0 +1,120 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17) + +project(tritonmmdeploybackend LANGUAGES C CXX) + +# +# Options +# +# Must include options required for this project as well as any +# projects included in this one by FetchContent. +# +# GPU support is disabled by default because recommended backend +# doesn't use GPUs. +# + +if (NOT TRITON_TAG) + set(TRITON_TAG main) +endif() + +set(TRITON_COMMON_REPO_TAG ${TRITON_TAG} CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG ${TRITON_TAG} CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG ${TRITON_TAG} CACHE STRING "Tag for triton-inference-server/backend repo") + + +# +# Dependencies +# +# FetchContent requires us to include the transitive closure of all +# repos that we depend on so that we can override the tags. +# +include(FetchContent) + +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend) + +add_library(triton-mmdeploy-backend SHARED + model_state.cpp + instance_state.cpp + convert.cpp + json_input.cpp + mmdeploy.cpp) + +target_include_directories(triton-mmdeploy-backend PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + +target_compile_options( + triton-mmdeploy-backend PRIVATE + $<$,$,$>: + -Wall -Wno-unused-parameter -Wno-type-limits -Werror> + $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc> +) + +target_link_libraries( + triton-mmdeploy-backend + PRIVATE + triton-core-serverapi # from repo-core + triton-core-backendapi # from repo-core + triton-core-serverstub # from repo-core + triton-backend-utils # from repo-backend +) + +mmdeploy_load_static(triton-mmdeploy-backend MMDeployStaticModules) +target_link_libraries(triton-mmdeploy-backend PRIVATE MMDeployLibs) + +set_target_properties(triton-mmdeploy-backend PROPERTIES INSTALL_RPATH "\$ORIGIN") + +install(TARGETS triton-mmdeploy-backend DESTINATION backends/mmdeploy) + +if (WIN32) + set_target_properties( + triton-mmdeploy-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_mmdeploy + ) +else () + set_target_properties( + triton-mmdeploy-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_mmdeploy + ) +endif () diff --git a/csrc/mmdeploy/triton/convert.cpp b/csrc/mmdeploy/triton/convert.cpp new file mode 100644 index 0000000000..e558b7c41b --- /dev/null +++ b/csrc/mmdeploy/triton/convert.cpp @@ -0,0 +1,341 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "convert.h" + +#include + +#include "mmdeploy/archive/json_archive.h" +#include "mmdeploy/archive/value_archive.h" +#include "mmdeploy/codebase/mmaction/mmaction.h" +#include "mmdeploy/codebase/mmcls/mmcls.h" +#include "mmdeploy/codebase/mmdet/mmdet.h" +#include "mmdeploy/codebase/mmedit/mmedit.h" +#include "mmdeploy/codebase/mmocr/mmocr.h" +#include "mmdeploy/codebase/mmpose/mmpose.h" +#include "mmdeploy/codebase/mmrotate/mmrotate.h" +#include "mmdeploy/codebase/mmseg/mmseg.h" +#include "mmdeploy/core/utils/formatter.h" +#include "triton/backend/backend_common.h" + +namespace mmdeploy { + +namespace core = framework; + +core::Tensor Mat2Tensor(core::Mat mat) { + TensorDesc desc{mat.device(), mat.type(), {mat.height(), mat.width(), mat.channel()}, ""}; + return {desc, mat.buffer()}; +} + +} // namespace mmdeploy + +namespace triton::backend::mmdeploy { + +using Value = ::mmdeploy::Value; +using Tensor = ::mmdeploy::core::Tensor; +using TensorDesc = ::mmdeploy::core::TensorDesc; + +void ConvertClassifications(const Value& item, std::vector& tensors) { + ::mmdeploy::mmcls::Labels classify_outputs; + ::mmdeploy::from_value(item, classify_outputs); + Tensor labels(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(classify_outputs.size())}, + "labels"}); + Tensor scores(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(classify_outputs.size())}, + "scores"}); + auto labels_data = labels.data(); + auto scores_data = scores.data(); + for (const auto& c : classify_outputs) { + *labels_data++ = c.label_id; + *scores_data++ = c.score; + } + tensors.push_back(std::move(labels)); + tensors.push_back(std::move(scores)); +} + +void ConvertDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmdet::Detections detections; + ::mmdeploy::from_value(item, detections); + Tensor bboxes(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.size()), 5}, + "bboxes"}); + Tensor labels(TensorDesc{bboxes.device(), + ::mmdeploy::DataType::kINT32, + {static_cast(detections.size())}, + "labels"}); + auto bboxes_data = bboxes.data(); + auto labels_data = labels.data(); + int64_t sum_byte_size = 0; + for (const auto& det : detections) { + for (const auto& x : det.bbox) { + *bboxes_data++ = x; + } + *bboxes_data++ = det.score; + *labels_data++ = det.label_id; + sum_byte_size += det.mask.byte_size(); + } + tensors.push_back(std::move(bboxes)); + tensors.push_back(std::move(labels)); + if (sum_byte_size > 0) { + // return mask + Tensor masks(TensorDesc{bboxes.device(), + ::mmdeploy::DataType::kINT8, + {static_cast(sum_byte_size)}, + "masks"}); + Tensor offs(TensorDesc{bboxes.device(), + ::mmdeploy::DataType::kINT32, + {static_cast(detections.size()), 3}, + "mask_offs"}); // [(off, w, h), ... ] + + auto masks_data = masks.data(); + auto offs_data = offs.data(); + int sum_offs = 0; + for (const auto& det : detections) { + memcpy(masks_data, det.mask.data(), det.mask.byte_size()); + masks_data += det.mask.byte_size(); + *offs_data++ = sum_offs; + *offs_data++ = det.mask.width(); + *offs_data++ = det.mask.height(); + sum_offs += det.mask.byte_size(); + } + tensors.push_back(std::move(masks)); + tensors.push_back(std::move(offs)); + } +} + +void ConvertSegmentation(const Value& item, std::vector& tensors) { + ::mmdeploy::mmseg::SegmentorOutput seg; + ::mmdeploy::from_value(item, seg); + if (seg.score.size()) { + auto desc = seg.score.desc(); + desc.name = "score"; + tensors.emplace_back(desc, seg.score.buffer()); + } + if (seg.mask.size()) { + auto desc = seg.mask.desc(); + desc.name = "mask"; + tensors.emplace_back(desc, seg.mask.buffer()); + tensors.back().Squeeze(); + } +} + +void ConvertMats(const Value& item, std::vector& tensors) { + ::mmdeploy::mmedit::RestorerOutput restoration; + ::mmdeploy::from_value(item, restoration); + tensors.push_back(::mmdeploy::Mat2Tensor(restoration)); +} + +void ConvertTextDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmocr::TextDetections detections; + ::mmdeploy::from_value(item, detections); + Tensor bboxes(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.size()), 8}, + "bboxes"}); + Tensor scores(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.size()), 1}, + "scores"}); + auto bboxes_data = bboxes.data(); + auto scores_data = scores.data(); + for (const auto& det : detections) { + bboxes_data = std::copy(det.bbox.begin(), det.bbox.end(), bboxes_data); + *scores_data++ = det.score; + } + tensors.push_back(std::move(bboxes)); + tensors.push_back(std::move(scores)); +} + +void ConvertTextRecognitions(const Value& item, int request_count, + const std::vector& batch_per_request, + std::vector>& tensors, + std::vector& strings) { + std::vector<::mmdeploy::mmocr::TextRecognition> recognitions; + ::mmdeploy::from_value(item, recognitions); + + int k = 0; + for (int i = 0; i < request_count; i++) { + int num = batch_per_request[i]; + Tensor texts(TensorDesc{ + ::mmdeploy::Device(0), ::mmdeploy::DataType::kINT32, {static_cast(num)}, "texts"}); + Tensor score(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(num)}, + "scores"}); + auto text_data = texts.data(); + auto score_data = score.data(); + + for (int j = 0; j < num; j++) { + auto& recognition = recognitions[k++]; + text_data[j] = static_cast(strings.size()); + strings.push_back(recognition.text); + score_data[j] = static_cast(strings.size()); + strings.push_back(::mmdeploy::to_json(::mmdeploy::to_value(recognition.score)).dump()); + } + tensors[i].push_back(std::move(texts)); + tensors[i].push_back(std::move(score)); + } +} + +void ConvertTextRecognitions(const Value& item, std::vector& tensors, + std::vector& strings) { + std::vector<::mmdeploy::mmocr::TextRecognition> recognitions; + ::mmdeploy::from_value(item, recognitions); + + Tensor texts(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(recognitions.size())}, + "rec_texts"}); + Tensor score(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(recognitions.size())}, + "rec_scores"}); + auto text_data = texts.data(); + auto score_data = score.data(); + + for (size_t j = 0; j < recognitions.size(); j++) { + auto& recognition = recognitions[j]; + text_data[j] = static_cast(strings.size()); + strings.push_back(recognition.text); + score_data[j] = static_cast(strings.size()); + strings.push_back(::mmdeploy::to_json(::mmdeploy::to_value(recognition.score)).dump()); + } + tensors.push_back(std::move(texts)); + tensors.push_back(std::move(score)); +} + +void ConvertPreprocess(const Value& item, std::vector& tensors, + std::vector& strings) { + Value::Object img_metas; + for (auto it = item.begin(); it != item.end(); ++it) { + if (it->is_any()) { + auto tensor = it->get(); + auto desc = tensor.desc(); + desc.name = it.key(); + tensors.emplace_back(desc, tensor.buffer()); + } else if (!it->is_any<::mmdeploy::framework::Mat>()) { + img_metas.insert({it.key(), *it}); + } + } + auto index = static_cast(strings.size()); + strings.push_back(::mmdeploy::format_value(img_metas)); + Tensor img_meta_tensor( + TensorDesc{::mmdeploy::Device(0), ::mmdeploy::DataType::kINT32, {1}, "img_metas"}); + *img_meta_tensor.data() = index; + tensors.push_back(std::move(img_meta_tensor)); +} + +void ConvertInference(const Value& item, std::vector& tensors) { + for (auto it = item.begin(); it != item.end(); ++it) { + auto tensor = it->get(); + auto desc = tensor.desc(); + desc.name = it.key(); + tensors.emplace_back(desc, tensor.buffer()); + } +} + +void ConvertPoseDetections(const Value& item, int request_count, + const std::vector& batch_per_request, + std::vector>& tensors) { + std::vector<::mmdeploy::mmpose::PoseDetectorOutput> detections; + ::mmdeploy::from_value(item, detections); + + int k = 0; + for (int i = 0; i < request_count; i++) { + int num = batch_per_request[i]; + Tensor pts(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {num, static_cast(detections[0].key_points.size()), 3}, + "keypoints"}); + auto pts_data = pts.data(); + for (int j = 0; j < num; j++) { + auto& detection = detections[k++]; + for (const auto& p : detection.key_points) { + *pts_data++ = p.bbox[0]; + *pts_data++ = p.bbox[1]; + *pts_data++ = p.score; + } + } + tensors[i].push_back(std::move(pts)); + } +} + +void ConvertRotatedDetections(const Value& item, std::vector& tensors) { + ::mmdeploy::mmrotate::RotatedDetectorOutput detections; + ::mmdeploy::from_value(item, detections); + Tensor bboxes(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kFLOAT, + {static_cast(detections.detections.size()), 6}, + "bboxes"}); + Tensor labels(TensorDesc{::mmdeploy::Device(0), + ::mmdeploy::DataType::kINT32, + {static_cast(detections.detections.size())}, + "labels"}); + auto bboxes_data = bboxes.data(); + auto labels_data = labels.data(); + for (const auto& det : detections.detections) { + bboxes_data = std::copy(det.rbbox.begin(), det.rbbox.end(), bboxes_data); + *bboxes_data++ = det.score; + *labels_data++ = det.label_id; + } + tensors.push_back(std::move(bboxes)); + tensors.push_back(std::move(labels)); +} + +std::vector> ConvertOutputToTensors(const std::string& type, + int32_t request_count, + const std::vector& batch_per_request, + const Value& output, + std::vector& strings) { + std::vector> tensors(request_count); + if (type == "Preprocess") { + for (int i = 0; i < request_count; ++i) { + ConvertPreprocess(output.front()[i], tensors[i], strings); + } + } else if (type == "Inference") { + for (int i = 0; i < request_count; ++i) { + ConvertInference(output.front()[i], tensors[i]); + } + } else if (type == "Classifier") { + for (int i = 0; i < request_count; ++i) { + ConvertClassifications(output.front()[i], tensors[i]); + } + } else if (type == "Detector") { + for (int i = 0; i < request_count; ++i) { + ConvertDetections(output.front()[i], tensors[i]); + } + } else if (type == "Segmentor") { + for (int i = 0; i < request_count; ++i) { + ConvertSegmentation(output.front()[i], tensors[i]); + } + } else if (type == "Restorer") { + for (int i = 0; i < request_count; ++i) { + ConvertMats(output.front()[i], tensors[i]); + } + } else if (type == "TextDetector") { + for (int i = 0; i < request_count; ++i) { + ConvertTextDetections(output.front()[i], tensors[i]); + } + } else if (type == "TextRecognizer") { + ConvertTextRecognitions(output.front(), request_count, batch_per_request, tensors, strings); + } else if (type == "PoseDetector") { + ConvertPoseDetections(output.front(), request_count, batch_per_request, tensors); + } else if (type == "RotatedDetector") { + for (int i = 0; i < request_count; ++i) { + ConvertRotatedDetections(output.front()[i], tensors[i]); + } + } else if (type == "TextOCR") { + for (int i = 0; i < request_count; ++i) { + ConvertTextDetections(output[0][i], tensors[i]); + ConvertTextRecognitions(output[1][i], tensors[i], strings); + } + } else { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, ("Unsupported type: " + type).c_str()); + } + return tensors; +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/convert.h b/csrc/mmdeploy/triton/convert.h new file mode 100644 index 0000000000..0646b53d80 --- /dev/null +++ b/csrc/mmdeploy/triton/convert.h @@ -0,0 +1,19 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_CONVERT_H +#define MMDEPLOY_CONVERT_H + +#include + +#include "mmdeploy/core/tensor.h" +#include "mmdeploy/core/value.h" + +namespace triton::backend::mmdeploy { + +std::vector> ConvertOutputToTensors( + const std::string& type, int32_t request_count, const std::vector& batch_per_request, + const ::mmdeploy::Value& output, std::vector& strings); + +} + +#endif // MMDEPLOY_CONVERT_H diff --git a/csrc/mmdeploy/triton/instance_state.cpp b/csrc/mmdeploy/triton/instance_state.cpp new file mode 100644 index 0000000000..ec8a9414a0 --- /dev/null +++ b/csrc/mmdeploy/triton/instance_state.cpp @@ -0,0 +1,477 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "instance_state.h" + +#include +#include + +#include "convert.h" +#include "json.hpp" +#include "json_input.h" +#include "mmdeploy/archive/json_archive.h" +#include "mmdeploy/core/device.h" +#include "mmdeploy/core/mat.h" +#include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy_utils.h" + +namespace triton::backend::mmdeploy { + +TRITONSERVER_Error* ModelInstanceState::Create(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) { + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} +ModelInstanceState::ModelInstanceState(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state), + pipeline_(model_state_->CreatePipeline(Kind(), DeviceId())) { + // parse parameters + ::triton::common::TritonJson::Value parameters; + model_state->ModelConfig().Find("parameters", ¶meters); + std::string info; + TryParseModelStringParameter(parameters, "merge_inputs", &info, ""); + if (info != "") { + std::stringstream ss1(info); + std::string group; + while (std::getline(ss1, group, ',')) { + std::stringstream ss2(group); + merge_inputs_.emplace_back(); + int v; + while (ss2 >> v) { + merge_inputs_.back().push_back(v); + } + } + } +} + +// TRITON DIR MMDeploy +// (Tensor, PixFmt, Region) -> (Mat , Region) +// [Tensor] <- ([Tensor], Meta ) +// [Tensor] -> ([Tensor], Meta ) +// [Tensor] <- [Value] + +TRITONSERVER_Error* ModelInstanceState::Execute(TRITONBACKEND_Request** requests, + uint32_t request_count) { + // Collect various timestamps during the execution of this batch or + // requests. These values are reported below before returning from + // the function. + + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + ModelState* model_state = StateForModel(); + + const int max_batch_size = model_state->MaxBatchSize(); + + for (size_t i = 0; i < request_count; ++i) { + if (requests[i] == nullptr) { + RequestsRespondWithError( + requests, request_count, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("null request given to MMDeploy backend for '" + Name() + "'").c_str())); + return nullptr; + } + + if (max_batch_size > 0) { + // Retrieve the batch size from one of the inputs, if the model + // supports batching, the first dimension size is batch size + // and batch dim should be 1 for mmdeploy + TRITONBACKEND_Input* input; + TRITONSERVER_Error* err = + TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input); + if (err == nullptr) { + const int64_t* shape; + err = TRITONBACKEND_InputProperties(input, nullptr, nullptr, &shape, nullptr, nullptr, + nullptr); + if (err == nullptr && shape[0] != 1) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("only support batch dim 1 for single request").c_str()); + } + } + if (err != nullptr) { + RequestsRespondWithError(requests, request_count, err); + return nullptr; + } + } + } + + // 'responses' is initialized as a parallel array to 'requests', + // with one TRITONBACKEND_Response object for each + // TRITONBACKEND_Request object. If something goes wrong while + // creating these response objects, the backend simply returns an + // error from TRITONBACKEND_ModelInstanceExecute, indicating to + // Triton that this backend did not create or send any responses and + // so it is up to Triton to create and send an appropriate error + // response for each request. RETURN_IF_ERROR is one of several + // useful macros for error handling that can be found in + // backend_common.h. + + std::vector responses; + responses.reserve(request_count); + for (uint32_t r = 0; r < request_count; ++r) { + TRITONBACKEND_Request* request = requests[r]; + TRITONBACKEND_Response* response; + RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, request)); + responses.push_back(response); + } + + std::vector> allowed_input_types = { + {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; + + std::vector> collectors(request_count); + std::vector> response_vecs(request_count); + bool need_cuda_input_sync = false; + + for (uint32_t request_index = 0; request_index < request_count; ++request_index) { + response_vecs[request_index] = {responses[request_index]}; + collectors[request_index] = std::make_unique( + &requests[request_index], 1, &response_vecs[request_index], + model_state->TritonMemoryManager(), false, CudaStream()); + } + + // Setting input data + ::mmdeploy::Value vec_inputs; + std::vector batch_per_request; + for (uint32_t request_index = 0; request_index < request_count; ++request_index) { + const auto& collector = collectors[request_index]; + ::mmdeploy::Value vec_inputi; + batch_per_request.push_back(1); + + for (size_t input_id = 0; input_id < model_state->input_names().size(); ++input_id) { + ::mmdeploy::Value inputi; + const auto& input_name = model_state->input_names()[input_id]; + // Get input shape + TRITONBACKEND_Input* input{}; + RETURN_IF_ERROR( + TRITONBACKEND_RequestInput(requests[request_index], input_name.c_str(), &input)); + TRITONSERVER_DataType data_type{}; + const int64_t* dims{}; + uint32_t dims_count{}; + RETURN_IF_ERROR(TRITONBACKEND_InputProperties(input, nullptr, &data_type, &dims, &dims_count, + nullptr, nullptr)); + if (data_type != TRITONSERVER_TYPE_BYTES) { + // Collect input buffer + const char* buffer{}; + size_t buffer_size{}; + TRITONSERVER_MemoryType memory_type{}; + int64_t memory_type_id{}; + RETURN_IF_ERROR(collector->ProcessTensor(input_name.c_str(), nullptr, 0, + allowed_input_types, &buffer, &buffer_size, + &memory_type, &memory_type_id)); + ::mmdeploy::framework::Device device(0); + if (memory_type == TRITONSERVER_MEMORY_GPU) { + device = ::mmdeploy::framework::Device("cuda", static_cast(memory_type_id)); + } + + if (model_state->input_formats()[input_id] == "FORMAT_NHWC") { + // Construct Mat from shape & buffer + int h, w; + if (max_batch_size > 0) { + h = dims[1]; + w = dims[2]; + } else { + h = dims[0]; + w = dims[1]; + } + ::mmdeploy::framework::Mat mat( + h, w, ::mmdeploy::PixelFormat::kBGR, ::mmdeploy::DataType::kINT8, + std::shared_ptr(const_cast(buffer), [](auto) {}), device); + inputi = {{input_name, mat}}; + } else { + ::mmdeploy::framework::Tensor tensor( + ::mmdeploy::framework::TensorDesc{ + device, ConvertDataType(model_state->input_data_types()[input_id]), + ::mmdeploy::framework::TensorShape(dims, dims + dims_count), input_name}, + std::shared_ptr(const_cast(buffer), [](auto) {})); + inputi = {{input_name, std::move(tensor)}}; + } + } else { + ::mmdeploy::Value value; + GetStringInputTensor(input, dims, dims_count, value); + assert(value.is_array()); + + if (value[0].contains("type")) { + const auto& type = value[0]["type"].get_ref(); + CreateJsonInput(value[0]["value"], type, inputi); + batch_per_request.back() = inputi.size(); + } else { + inputi = {{}}; + inputi.update(value.front().object()); + } + } + vec_inputi.push_back(std::move(inputi)); // [ a, [b,b] ] + } + + // broadcast, [ a, [b,b] ] -> [[a, a], [b, b]] + if (batch_per_request.back() >= 1) { + // std::vector<::mmdeploy::Value> input; + ::mmdeploy::Value input; + for (size_t i = 0; i < vec_inputi.size(); i++) { + input.push_back(::mmdeploy::Value::kArray); + } + + for (int i = 0; i < batch_per_request.back(); i++) { + for (size_t input_id = 0; input_id < model_state->input_names().size(); ++input_id) { + if (vec_inputi[input_id].is_object()) { + input[input_id].push_back(vec_inputi[input_id]); + } else { + input[input_id].push_back(vec_inputi[input_id][i]); + } + } + } + vec_inputi = input; + } + + // construct [[a,a,a], [b,b,b]] + if (vec_inputs.is_null()) { + for (size_t i = 0; i < vec_inputi.size(); i++) { + vec_inputs.push_back(::mmdeploy::Value::kArray); + } + } + for (size_t i = 0; i < vec_inputi.size(); i++) { + auto&& inner = vec_inputi[i]; + for (auto&& obj : inner) { + vec_inputs[i].push_back(std::move(obj)); + } + } + } + + // merge inputs for example: [[a,a,a], [b,b,b], [c,c,c]] -> [[aaa], [(b,c), (b,c), (b,c)]] + if (!merge_inputs_.empty()) { + int n_example = vec_inputs[0].size(); + ::mmdeploy::Value inputs; + for (const auto& group : merge_inputs_) { + ::mmdeploy::Value input_array; + for (int i = 0; i < n_example; i++) { + ::mmdeploy::Value input_i; + for (const auto& idx : group) { + auto&& inner = vec_inputs[idx]; + input_i.update(inner[i]); + } + input_array.push_back(std::move(input_i)); + } + inputs.push_back(std::move(input_array)); + } + vec_inputs = std::move(inputs); + } + + if (need_cuda_input_sync) { +#if TRITON_ENABLE_GPU + cudaStreamSynchronize(CudaStream()); +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + "mmdeploy backend: unexpected CUDA sync required by collector"); +#endif + } + + uint64_t compute_start_ns = 0; + SET_TIMESTAMP(compute_start_ns); + + ::mmdeploy::Value outputs = pipeline_.Apply(vec_inputs); + // MMDEPLOY_ERROR("outputs:\n{}", outputs); + + // preprocess and inference need cuda sync + { + std::string device_name = "cpu"; + if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + device_name = "cuda"; + } + auto device = ::mmdeploy::framework::Device(device_name.c_str(), DeviceId()); + auto stream = ::mmdeploy::framework::Stream::GetDefault(device); + stream.Wait(); + } + + std::vector strings; + auto output_tensors = ConvertOutputToTensors(model_state->task_type(), request_count, + batch_per_request, outputs, strings); + uint64_t compute_end_ns = 0; + SET_TIMESTAMP(compute_end_ns); + + std::vector> responders(request_count); + MMDEPLOY_DEBUG("request_count {}", request_count); + for (uint32_t request_index = 0; request_index < request_count; ++request_index) { + responders[request_index] = std::make_unique( + &requests[request_index], 1, &response_vecs[request_index], false, + model_state->TritonMemoryManager(), false, CudaStream()); + for (size_t output_id = 0; output_id < model_state->output_names().size(); ++output_id) { + auto output_name = model_state->output_names()[output_id]; + MMDEPLOY_DEBUG("output name {}", output_name); + auto output_data_type = model_state->output_data_types()[output_id]; + for (const auto& tensor : output_tensors[request_index]) { + if (tensor.name() == output_name) { + if (output_data_type != TRITONSERVER_TYPE_BYTES) { + auto shape = tensor.shape(); + MMDEPLOY_DEBUG("name {}, shape {}", tensor.name(), shape); + auto memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + if (not tensor.device().is_host()) { + memory_type = TRITONSERVER_MEMORY_GPU; + memory_type_id = tensor.device().device_id(); + } + responders[request_index]->ProcessTensor( + tensor.name(), ConvertDataType(tensor.data_type()), shape, tensor.data(), + memory_type, memory_type_id); + } else { + RETURN_IF_ERROR(SetStringOutputTensor(tensor, strings, responses[request_index])); + } + break; + } + } + } + + const bool need_cuda_output_sync = responders[request_index]->Finalize(); + if (need_cuda_output_sync) { +#if TRITON_ENABLE_GPU + cudaStreamSynchronize(CudaStream()); +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + "mmdeploy backend: unexpected CUDA sync required by responder"); +#endif + } + } + + // Send all the responses that haven't already been sent because of + // an earlier error. + for (auto& response : responses) { + if (response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), + "failed to send response"); + } + } + + uint64_t exec_end_ns = 0; + SET_TIMESTAMP(exec_end_ns); + +#ifdef TRITON_ENABLE_STATS + // For batch statistics need to know the total batch size of the + // requests. This is not necessarily just the number of requests, + // because if the model supports batching then any request can be a + // batched request itself. + size_t total_batch_size = request_count; +#else + (void)exec_start_ns; + (void)exec_end_ns; + (void)compute_start_ns; + (void)compute_end_ns; +#endif // TRITON_ENABLE_STATS + + // Report statistics for each request, and then release the request. + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + +#ifdef TRITON_ENABLE_STATS + LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics( + TritonModelInstance(), request, (responses[r] != nullptr) /* success */, + exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting request statistics"); +#endif // TRITON_ENABLE_STATS + + LOG_IF_ERROR(TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + +#ifdef TRITON_ENABLE_STATS + // Report batch statistics. + LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics( + TritonModelInstance(), total_batch_size, exec_start_ns, compute_start_ns, + compute_end_ns, exec_end_ns), + "failed reporting batch request statistics"); +#endif + + return nullptr; // success +} + +TRITONSERVER_Error* ModelInstanceState::GetStringInputTensor(TRITONBACKEND_Input* input, + const int64_t* dims, + uint32_t dims_count, + ::mmdeploy::Value& value) { + ::mmdeploy::Value::Array array; + const char* buffer{}; + uint64_t buffer_byte_size{}; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id{}; + RETURN_IF_ERROR(TRITONBACKEND_InputBuffer(input, 0, reinterpret_cast(&buffer), + &buffer_byte_size, &memory_type, &memory_type_id)); + auto count = std::accumulate(dims, dims + dims_count, 1LL, std::multiplies<>{}); + size_t offset = 0; + for (int64_t i = 0; i < count; ++i) { + // read string length + if (offset + sizeof(uint32_t) > buffer_byte_size) { + break; + } + auto length = *reinterpret_cast(buffer + offset); + offset += sizeof(uint32_t); + // read string data + if (offset + length > buffer_byte_size) { + break; + } + std::string data(buffer + offset, buffer + offset + length); + offset += length; + // deserialize from json string + auto data_value = ::mmdeploy::from_json<::mmdeploy::Value>(nlohmann::json::parse(data)); + array.push_back(std::move(data_value)); + } + value = std::move(array); + return nullptr; +} + +TRITONSERVER_Error* ModelInstanceState::SetStringOutputTensor( + const ::mmdeploy::framework::Tensor& tensor, const std::vector& strings, + TRITONBACKEND_Response* response) { + assert(tensor.data_type() == ::mmdeploy::DataType::kINT32); + TRITONSERVER_Error* err{}; + TRITONBACKEND_Output* response_output{}; + err = TRITONBACKEND_ResponseOutput(response, &response_output, tensor.name(), + TRITONSERVER_TYPE_BYTES, tensor.shape().data(), + tensor.shape().size()); + if (!err) { + size_t data_byte_size{}; + auto index_data = tensor.data(); + auto size = tensor.size(); + for (int64_t j = 0; j < size; ++j) { + data_byte_size += strings[index_data[j]].size(); + } + auto expected_byte_size = data_byte_size + sizeof(uint32_t) * size; + void* buffer{}; + TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU; + int64_t actual_memory_type_id = 0; + err = TRITONBACKEND_OutputBuffer(response_output, &buffer, expected_byte_size, + &actual_memory_type, &actual_memory_type_id); + if (!err) { + bool cuda_used = false; + size_t copied_byte_size = 0; + for (int64_t j = 0; j < size; ++j) { + auto len = static_cast(strings[index_data[j]].size()); + err = CopyBuffer(tensor.name(), TRITONSERVER_MEMORY_CPU, 0, actual_memory_type, + actual_memory_type_id, sizeof(uint32_t), &len, + static_cast(buffer) + copied_byte_size, nullptr, &cuda_used); + if (err) { + break; + } + copied_byte_size += sizeof(uint32_t); + err = CopyBuffer(tensor.name(), TRITONSERVER_MEMORY_CPU, 0, actual_memory_type, + actual_memory_type_id, len, strings[index_data[j]].data(), + static_cast(buffer) + copied_byte_size, nullptr, &cuda_used); + if (err) { + break; + } + copied_byte_size += len; + } + } + } + return err; +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/instance_state.h b/csrc/mmdeploy/triton/instance_state.h new file mode 100644 index 0000000000..7204a5237d --- /dev/null +++ b/csrc/mmdeploy/triton/instance_state.h @@ -0,0 +1,44 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_INSTANCE_STATE_H +#define MMDEPLOY_INSTANCE_STATE_H + +#include "mmdeploy/core/tensor.h" +#include "model_state.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" + +namespace triton::backend::mmdeploy { + +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + ~ModelInstanceState() override = default; + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + TRITONSERVER_Error* Execute(TRITONBACKEND_Request** requests, uint32_t request_count); + + TRITONSERVER_Error* GetStringInputTensor(TRITONBACKEND_Input* input, const int64_t* dims, + uint32_t dims_count, ::mmdeploy::Value& value); + + TRITONSERVER_Error* SetStringOutputTensor(const ::mmdeploy::framework::Tensor& tensor, + const std::vector& strings, + TRITONBACKEND_Response* response); + + private: + ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance); + + private: + ModelState* model_state_; + ::mmdeploy::Pipeline pipeline_; + std::vector> merge_inputs_; +}; + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_INSTANCE_STATE_H diff --git a/csrc/mmdeploy/triton/json_input.cpp b/csrc/mmdeploy/triton/json_input.cpp new file mode 100644 index 0000000000..9ed0803cd4 --- /dev/null +++ b/csrc/mmdeploy/triton/json_input.cpp @@ -0,0 +1,14 @@ +#include "json_input.h" + +namespace triton::backend::mmdeploy { + +void CreateJsonInput(::mmdeploy::Value &input, const std::string &type, ::mmdeploy::Value &output) { + if (type == "TextBbox") { + output = input; + } + if (type == "PoseBbox") { + output = input; + } +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/json_input.h b/csrc/mmdeploy/triton/json_input.h new file mode 100644 index 0000000000..cf6c67102d --- /dev/null +++ b/csrc/mmdeploy/triton/json_input.h @@ -0,0 +1,22 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_TRITON_JSON_INPUT_H +#define MMDEPLOY_TRITON_JSON_INPUT_H + +#include +#include + +#include "mmdeploy/archive/value_archive.h" + +namespace triton::backend::mmdeploy { + +struct TextBbox { + std::array bbox; + MMDEPLOY_ARCHIVE_MEMBERS(bbox); +}; + +void CreateJsonInput(::mmdeploy::Value &input, const std::string &type, ::mmdeploy::Value &output); + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_TRITON_JSON_INPUT_H diff --git a/csrc/mmdeploy/triton/mmdeploy.cpp b/csrc/mmdeploy/triton/mmdeploy.cpp new file mode 100644 index 0000000000..4d1f6ca3f5 --- /dev/null +++ b/csrc/mmdeploy/triton/mmdeploy.cpp @@ -0,0 +1,150 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "instance_state.h" +#include "mmdeploy/core/logger.h" +#include "model_state.h" +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/core/tritonbackend.h" + +namespace triton::backend::mmdeploy { + +extern "C" { + +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) { + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); + std::string name(cname); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); + + // Check the backend API version that Triton supports vs. what this + // backend was compiled against. Make sure that the Triton major + // version is the same and the minor version is >= what this backend + // uses. + uint32_t api_version_major, api_version_minor; + RETURN_IF_ERROR(TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + (std::string("Triton TRITONBACKEND API version: ") + + std::to_string(api_version_major) + "." + std::to_string(api_version_minor)) + .c_str()); + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("'") + name + "' TRITONBACKEND API version: " + + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) + .c_str()); + + if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) || + (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, + "triton backend API version does not support this backend"); + } + + // The backend configuration may contain information needed by the + // backend, such as tritonserver command-line arguments. This + // backend doesn't use any such configuration but for this example + // print whatever is available. + TRITONSERVER_Message* backend_config_message; + RETURN_IF_ERROR(TRITONBACKEND_BackendConfig(backend, &backend_config_message)); + + const char* buffer; + size_t byte_size; + RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(backend_config_message, &buffer, &byte_size)); + LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("backend configuration:\n") + buffer).c_str()); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_Finalize when a backend is no longer +// needed. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) { + return nullptr; // success +} + +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) { + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + return nullptr; // success +} + +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + auto model_state = reinterpret_cast(vstate); + delete model_state; + + return nullptr; // success +} +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInstanceInitialize when a model +// instance is created to allow the backend to initialize any state +// associated with the instance. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize( + TRITONBACKEND_ModelInstance* instance) { + // Get the model state associated with this instance's model. + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + // Create a ModelInstanceState object and associate it with the + // TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR(ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceSetState(instance, reinterpret_cast(instance_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelInstanceFinalize when a model +// instance is no longer needed. The backend should cleanup any state +// associated with the model instance. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize( + TRITONBACKEND_ModelInstance* instance) { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = reinterpret_cast(vstate); + delete instance_state; + + return nullptr; // success +} + +} // extern "C" + +extern "C" { + +// When Triton calls TRITONBACKEND_ModelInstanceExecute it is required +// that a backend create a response for each request in the batch. A +// response may be the output tensors required for that request or may +// be an error that is returned in the response. +// +MMDEPLOY_EXPORT TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) { + // Triton will not call this function simultaneously for the same + // 'instance'. But since this backend could be used by multiple + // instances from multiple models the implementation needs to handle + // multiple calls to this function at the same time (with different + // 'instance' objects). Best practice for a high-performance + // implementation is to avoid introducing mutex/lock and instead use + // only function-local and model-instance-specific state. + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast(&instance_state))); + return instance_state->Execute(requests, request_count); +} + +} // extern "C" + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/mmdeploy_utils.h b/csrc/mmdeploy/triton/mmdeploy_utils.h new file mode 100644 index 0000000000..2fad294ad7 --- /dev/null +++ b/csrc/mmdeploy/triton/mmdeploy_utils.h @@ -0,0 +1,48 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_MMDEPLOY_UTILS_H +#define MMDEPLOY_MMDEPLOY_UTILS_H + +#include "mmdeploy/core/types.h" + +namespace triton::backend::mmdeploy { + +inline TRITONSERVER_DataType ConvertDataType(::mmdeploy::DataType data_type) { + using namespace ::mmdeploy::data_types; + switch (data_type) { + case kFLOAT: + return TRITONSERVER_TYPE_FP32; + case kHALF: + return TRITONSERVER_TYPE_FP16; + case kINT8: + return TRITONSERVER_TYPE_UINT8; + case kINT32: + return TRITONSERVER_TYPE_INT32; + case kINT64: + return TRITONSERVER_TYPE_INT64; + default: + return TRITONSERVER_TYPE_INVALID; + } +} + +inline ::mmdeploy::DataType ConvertDataType(TRITONSERVER_DataType data_type) { + using namespace ::mmdeploy::data_types; + switch (data_type) { + case TRITONSERVER_TYPE_FP32: + return kFLOAT; + case TRITONSERVER_TYPE_FP16: + return kHALF; + case TRITONSERVER_TYPE_UINT8: + return kINT8; + case TRITONSERVER_TYPE_INT32: + return kINT32; + case TRITONSERVER_TYPE_INT64: + return kINT64; + default: + return ::mmdeploy::DataType::kCOUNT; + } +} + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_MMDEPLOY_UTILS_H diff --git a/csrc/mmdeploy/triton/model_state.cpp b/csrc/mmdeploy/triton/model_state.cpp new file mode 100644 index 0000000000..26c34e496f --- /dev/null +++ b/csrc/mmdeploy/triton/model_state.cpp @@ -0,0 +1,121 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "model_state.h" + +#include + +#include "mmdeploy/archive/json_archive.h" +#include "mmdeploy/archive/value_archive.h" +#include "mmdeploy/core/utils/filesystem.h" +#include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy/pipeline.hpp" + +namespace triton::backend::mmdeploy { + +TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) { + try { + *state = new ModelState(triton_model); + } catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return {}; +} + +ModelState::ModelState(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model) { + THROW_IF_BACKEND_MODEL_ERROR(ValidateModelConfig()); +} + +TRITONSERVER_Error* ModelState::ValidateModelConfig() { + common::TritonJson::Value inputs; + common::TritonJson::Value outputs; + RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs)); + RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs)); + + for (size_t i = 0; i < inputs.ArraySize(); ++i) { + common::TritonJson::Value input; + RETURN_IF_ERROR(inputs.IndexAsObject(i, &input)); + + triton::common::TritonJson::Value reshape; + RETURN_ERROR_IF_TRUE(input.Find("reshape", &reshape), TRITONSERVER_ERROR_UNSUPPORTED, + std::string("reshape not supported for input tensor")); + + std::string name; + RETURN_IF_ERROR(input.MemberAsString("name", &name)); + input_names_.push_back(name); + + std::string data_type; + RETURN_IF_ERROR(input.MemberAsString("data_type", &data_type)); + input_data_types_.push_back(ModelConfigDataTypeToTritonServerDataType(data_type)); + + std::string format; + RETURN_IF_ERROR(input.MemberAsString("format", &format)); + input_formats_.push_back(format); + } + + for (size_t i = 0; i < outputs.ArraySize(); ++i) { + common::TritonJson::Value output; + RETURN_IF_ERROR(outputs.IndexAsObject(i, &output)); + + triton::common::TritonJson::Value reshape; + RETURN_ERROR_IF_TRUE(output.Find("reshape", &reshape), TRITONSERVER_ERROR_UNSUPPORTED, + std::string("reshape not supported for output tensor")); + + std::string name; + RETURN_IF_ERROR(output.MemberAsString("name", &name)); + output_names_.push_back(name); + + std::string data_type; + RETURN_IF_ERROR(output.MemberAsString("data_type", &data_type)); + output_data_types_.push_back(ModelConfigDataTypeToTritonServerDataType(data_type)); + } + + return {}; +} + +::mmdeploy::Pipeline ModelState::CreatePipeline(TRITONSERVER_InstanceGroupKind kind, + int device_id) { + // infer device name + std::string device_name = "cpu"; + if (kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + device_name = "cuda"; + } + + std::string pipeline_template_path = + JoinPath({RepositoryPath(), std::to_string(Version()), "pipeline_template.json"}); + if (fs::exists(pipeline_template_path)) { + std::ifstream ifs(pipeline_template_path, std::ios::binary | std::ios::in); + ifs.seekg(0, std::ios::end); + auto size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::string str(size, '\0'); + ifs.read(str.data(), size); + + auto config = ::mmdeploy::from_json<::mmdeploy::Value>(nlohmann::json::parse(str)); + ::mmdeploy::Context context(::mmdeploy::Device(device_name, device_id)); + config["task_type"].get_to(task_type_); + config.object().erase("task_type"); + if (config.contains("model_names")) { + std::vector model_names; + ::mmdeploy::from_value(config["model_names"], model_names); + for (const auto& name : model_names) { + std::string model_path = JoinPath({RepositoryPath(), std::to_string(Version()), name}); + context.Add(name, ::mmdeploy::Model(model_path)); + } + config.object().erase("model_names"); + } + return {config, context}; + + } else { + model_ = ::mmdeploy::framework::Model(JoinPath({RepositoryPath(), std::to_string(Version())})); + auto config = model_.ReadConfig("pipeline.json").value(); + config["context"]["model"] = model_; + ::mmdeploy::Context context(::mmdeploy::Device(device_name, device_id)); + task_type_ = model_.meta().task; + return {config, context}; + } +} + +} // namespace triton::backend::mmdeploy diff --git a/csrc/mmdeploy/triton/model_state.h b/csrc/mmdeploy/triton/model_state.h new file mode 100644 index 0000000000..cbfaf00aa1 --- /dev/null +++ b/csrc/mmdeploy/triton/model_state.h @@ -0,0 +1,46 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_MODEL_STATE_H +#define MMDEPLOY_MODEL_STATE_H + +#define MMDEPLOY_CXX_USE_OPENCV 0 + +#include "mmdeploy/core/model.h" +#include "mmdeploy/pipeline.hpp" +#include "triton/backend/backend_model.h" + +namespace triton::backend::mmdeploy { + +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, ModelState** state); + + const std::vector& input_names() const { return input_names_; } + const std::vector& output_names() const { return output_names_; } + const std::vector& input_data_types() const { return input_data_types_; } + const std::vector& output_data_types() const { return output_data_types_; } + + const std::vector& input_formats() const { return input_formats_; } + + ::mmdeploy::Pipeline CreatePipeline(TRITONSERVER_InstanceGroupKind kind, int device_id); + + const std::string& task_type() { return task_type_; } + + private: + explicit ModelState(TRITONBACKEND_Model* triton_model); + + TRITONSERVER_Error* ValidateModelConfig(); + + private: + ::mmdeploy::framework::Model model_; + std::string task_type_; + std::vector input_names_; + std::vector output_names_; + std::vector input_data_types_; + std::vector output_data_types_; + std::vector input_formats_; +}; + +} // namespace triton::backend::mmdeploy + +#endif // MMDEPLOY_MODEL_STATE_H diff --git a/demo/triton/image-classification/README.md b/demo/triton/image-classification/README.md new file mode 100644 index 0000000000..e7cb270009 --- /dev/null +++ b/demo/triton/image-classification/README.md @@ -0,0 +1,44 @@ +# Image classification serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmpretrain/classification_tensorrt_static-224x224.py \ + ../mmpretrain/configs/resnet/resnet18_8xb32_in1k.py \ + https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth \ + ../mmpretrain/demo/demo.JPEG \ + --device cuda \ + --work-dir work_dir/resnet \ + --dump-info +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/resnet \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/image-classification/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/image-classification/grpc_client.py b/demo/triton/image-classification/grpc_client.py new file mode 100644 index 0000000000..92755daf7b --- /dev/null +++ b/demo/triton/image-classification/grpc_client.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import cv2 +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(results): + labels = results['labels'] + scores = results['scores'] + assert len(labels) == len(scores) + topk = len(labels) + print(f'top {topk} results:') + for i in range(topk): + print(f'label {labels[i]} score {scores[i]}') + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(results) diff --git a/demo/triton/image-classification/serving/model/1/README.md b/demo/triton/image-classification/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/image-classification/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/image-classification/serving/model/config.pbtxt b/demo/triton/image-classification/serving/model/config.pbtxt new file mode 100644 index 0000000000..c141614e46 --- /dev/null +++ b/demo/triton/image-classification/serving/model/config.pbtxt @@ -0,0 +1,20 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "scores" + data_type: TYPE_FP32 + dims: [ -1 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} diff --git a/demo/triton/instance-segmentation/README.md b/demo/triton/instance-segmentation/README.md new file mode 100644 index 0000000000..fabceeca4d --- /dev/null +++ b/demo/triton/instance-segmentation/README.md @@ -0,0 +1,44 @@ +# Instance segmentation serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py \ + ../mmdetection/configs/mask_rcnn/mask-rcnn_r50_fpn_2x_coco.py \ + https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_20200505_003907-3e542a40.pth \ + ../mmdetection/demo/demo.jpg \ + --work-dir work_dir/maskrcnn \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/maskrcnn \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/instance-segmentation/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/instance-segmentation/grpc_client.py b/demo/triton/instance-segmentation/grpc_client.py new file mode 100644 index 0000000000..6d9473ff0d --- /dev/null +++ b/demo/triton/instance-segmentation/grpc_client.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import math + +import cv2 +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + labels = results['labels'] + masks = results['masks'] + mask_offs = results['mask_offs'] + assert len(bboxes) == len(labels) + for i in range(len(bboxes)): + x1, y1, x2, y2, score = bboxes[i] + if score < 0.5: + continue + x1, y1, x2, y2 = map(int, (x1, y1, x2, y2)) + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1) + + off, w, h = mask_offs[i] + mask_data = masks[off:off + w * h] + mask = mask_data.reshape(h, w) + + blue, green, red = cv2.split(img) + x0 = int(max(math.floor(x1) - 1, 0)) + y0 = int(max(math.floor(y1) - 1, 0)) + mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]] + cv2.bitwise_or(mask, mask_img, mask_img) + img = cv2.merge([blue, green, red]) + + cv2.imwrite('instance-segmentation.jpg', img) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/instance-segmentation/serving/model/1/README.md b/demo/triton/instance-segmentation/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/instance-segmentation/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/instance-segmentation/serving/model/config.pbtxt b/demo/triton/instance-segmentation/serving/model/config.pbtxt new file mode 100644 index 0000000000..8e8d145bdd --- /dev/null +++ b/demo/triton/instance-segmentation/serving/model/config.pbtxt @@ -0,0 +1,30 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 5 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} +output { + name: "masks" + data_type: TYPE_UINT8 + dims: [ -1 ] +} +output { + name: "mask_offs" + data_type: TYPE_INT32 + dims: [ -1, 3 ] +} diff --git a/demo/triton/keypoint-detection/README.md b/demo/triton/keypoint-detection/README.md new file mode 100644 index 0000000000..5839ac10bc --- /dev/null +++ b/demo/triton/keypoint-detection/README.md @@ -0,0 +1,44 @@ +# Keypoint detection serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmpose/pose-detection_tensorrt_static-256x192.py \ + ../mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py \ + https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192-81c58e40_20220909.pth \ + demo/resources/human-pose.jpg \ + --work-dir work_dir/hrnet \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/hrnet \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/keypoint-detection/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/keypoint-detection/grpc_client.py b/demo/triton/keypoint-detection/grpc_client.py new file mode 100644 index 0000000000..191ee54a91 --- /dev/null +++ b/demo/triton/keypoint-detection/grpc_client.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json + +import cv2 +import numpy as np +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image, box): + """ + Args: + image: np.ndarray + box: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [ + InferInput(self._input_names[0], image.shape, 'UINT8'), + InferInput(self._input_names[1], box.shape, 'BYTES') + ] + inputs[0].set_data_from_numpy(image) + inputs[1].set_data_from_numpy(box) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + batch_keypoints = results['keypoints'] + for keypoints in batch_keypoints: + n = keypoints.shape[0] + for i in range(n): + x, y, score = keypoints[i] + x, y = map(int, (x, y)) + cv2.circle(img, (x, y), 1, (0, 255, 0), 2) + cv2.imwrite('keypoint-detection.jpg', img) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + bbox = { + 'type': 'PoseBbox', + 'value': [{ + 'bbox': [0.0, 0.0, img.shape[1], img.shape[0]] + }] + } + bbox = np.array([json.dumps(bbox).encode('utf-8')]) + results = client.infer(img, bbox) + visualize(img, results) diff --git a/demo/triton/keypoint-detection/serving/model/1/README.md b/demo/triton/keypoint-detection/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/keypoint-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/keypoint-detection/serving/model/config.pbtxt b/demo/triton/keypoint-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..1877c3b0fa --- /dev/null +++ b/demo/triton/keypoint-detection/serving/model/config.pbtxt @@ -0,0 +1,29 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +input { + name: "PoseBbox" + data_type: TYPE_STRING + dims: [ 1 ] + allow_ragged_batch: true +} + +output { + name: "keypoints" + data_type: TYPE_INT32 + dims: [ -1, -1, 3 ] +} + +parameters { + key: "merge_inputs", + value: { + string_value: "0 1" + } +} diff --git a/demo/triton/object-detection/README.md b/demo/triton/object-detection/README.md new file mode 100644 index 0000000000..eff432c6e3 --- /dev/null +++ b/demo/triton/object-detection/README.md @@ -0,0 +1,44 @@ +# Object detection serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \ + ../mmdetection/configs/retinanet/retinanet_r18_fpn_1x_coco.py \ + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth \ + ../mmdetection/demo/demo.jpg \ + --work-dir work_dir/retinanet \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/retinanet \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/object-detection/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/object-detection/grpc_client.py b/demo/triton/object-detection/grpc_client.py new file mode 100644 index 0000000000..321cabfcac --- /dev/null +++ b/demo/triton/object-detection/grpc_client.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import cv2 +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + labels = results['labels'] + assert len(bboxes) == len(labels) + for i in range(len(bboxes)): + x1, y1, x2, y2, score = bboxes[i] + if score < 0.5: + continue + x1, y1, x2, y2 = map(int, (x1, y1, x2, y2)) + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1) + cv2.imwrite('object-detection.jpg', img) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/object-detection/serving/model/1/README.md b/demo/triton/object-detection/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/object-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/object-detection/serving/model/config.pbtxt b/demo/triton/object-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..04913e0181 --- /dev/null +++ b/demo/triton/object-detection/serving/model/config.pbtxt @@ -0,0 +1,20 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 5 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} diff --git a/demo/triton/oriented-object-detection/README.md b/demo/triton/oriented-object-detection/README.md new file mode 100644 index 0000000000..670c53c483 --- /dev/null +++ b/demo/triton/oriented-object-detection/README.md @@ -0,0 +1,44 @@ +# Oriented object detection serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmrotate/rotated-detection_tensorrt_dynamic-320x320-1024x1024.py \ + ../mmrotate/configs/rotated_faster_rcnn/rotated-faster-rcnn-le90_r50_fpn_1x_dota.py \ + https://download.openmmlab.com/mmrotate/v0.1.0/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth \ + ../mmrotate/demo/demo.jpg \ + --dump-info \ + --work-dir work_dir/rrcnn \ + --device cuda +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/rrcnn \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/oriented-object-detection/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/oriented-object-detection/grpc_client.py b/demo/triton/oriented-object-detection/grpc_client.py new file mode 100644 index 0000000000..9e0525a068 --- /dev/null +++ b/demo/triton/oriented-object-detection/grpc_client.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from math import cos, sin + +import cv2 +import numpy as np +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + labels = results['labels'] + for rbbox, label_id in zip(bboxes, labels): + [cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1] + if score < 0.1: + continue + [wx, wy, hx, hy] = \ + 0.5 * np.array([w, w, -h, h]) * \ + np.array([cos(angle), sin(angle), sin(angle), cos(angle)]) + points = np.array([[[int(cx - wx - hx), + int(cy - wy - hy)], + [int(cx + wx - hx), + int(cy + wy - hy)], + [int(cx + wx + hx), + int(cy + wy + hy)], + [int(cx - wx + hx), + int(cy - wy + hy)]]]) + cv2.drawContours(img, points, -1, (0, 255, 0), 2) + cv2.imwrite('oriented-object-detection.jpg', img) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/oriented-object-detection/serving/model/1/README.md b/demo/triton/oriented-object-detection/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/oriented-object-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/oriented-object-detection/serving/model/config.pbtxt b/demo/triton/oriented-object-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..142570726f --- /dev/null +++ b/demo/triton/oriented-object-detection/serving/model/config.pbtxt @@ -0,0 +1,20 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 6 ] +} +output { + name: "labels" + data_type: TYPE_INT32 + dims: [ -1 ] +} diff --git a/demo/triton/semantic-segmentation/README.md b/demo/triton/semantic-segmentation/README.md new file mode 100644 index 0000000000..31e53a504a --- /dev/null +++ b/demo/triton/semantic-segmentation/README.md @@ -0,0 +1,44 @@ +# Semantic segmentation serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmseg/segmentation_tensorrt-fp16_static-512x1024.py \ + ../mmsegmentation/configs/pspnet/pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py \ + https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_512x1024_80k_cityscapes/pspnet_r18-d8_512x1024_80k_cityscapes_20201225_021458-09ffa746.pth \ + ../mmsegmentation/demo/demo.png \ + --work-dir work_dir/pspnet \ + --dump-info \ + --device cuda +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/pspnet \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/semantic-segmentation/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/semantic-segmentation/grpc_client.py b/demo/triton/semantic-segmentation/grpc_client.py new file mode 100644 index 0000000000..296723b034 --- /dev/null +++ b/demo/triton/semantic-segmentation/grpc_client.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import cv2 +import numpy as np +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +def get_palette(num_classes=256): + state = np.random.get_state() + # random color + np.random.seed(42) + palette = np.random.randint(0, 256, size=(num_classes, 3)) + np.random.set_state(state) + return [tuple(c) for c in palette] + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + if 'mask' in results: + seg = results['mask'] + else: + score = results['score'] + seg = np.argmax(score, axis=0) + + palette = get_palette() + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * 0.5 + color_seg * 0.5 + img = img.astype(np.uint8) + cv2.imwrite('semantic-segmentation.png', img) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/semantic-segmentation/serving/mask/model/1/README.md b/demo/triton/semantic-segmentation/serving/mask/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/mask/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt b/demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt new file mode 100644 index 0000000000..e7a71e570d --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/mask/model/config.pbtxt @@ -0,0 +1,15 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [-1, -1, 3] + allow_ragged_batch: true +} + +output { + name: "mask" + data_type: TYPE_INT32 + dims: [ -1, -1 ] +} diff --git a/demo/triton/semantic-segmentation/serving/score/model/1/README.md b/demo/triton/semantic-segmentation/serving/score/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/score/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt b/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt new file mode 100644 index 0000000000..3fe8eda342 --- /dev/null +++ b/demo/triton/semantic-segmentation/serving/score/model/config.pbtxt @@ -0,0 +1,15 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [-1, -1, 3] + allow_ragged_batch: true +} + +output { + name: "score" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] +} diff --git a/demo/triton/text-detection/README.md b/demo/triton/text-detection/README.md new file mode 100644 index 0000000000..ac4074c4f3 --- /dev/null +++ b/demo/triton/text-detection/README.md @@ -0,0 +1,44 @@ +# Text detection serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py \ + ../mmocr/configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py \ + https://download.openmmlab.com/mmocr/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth \ + ../mmocr/demo/demo_text_det.jpg \ + --work-dir work_dir/panet \ + --dump-info \ + --device cuda:0 +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/panet \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/text-detection/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/text-detection/grpc_client.py b/demo/triton/text-detection/grpc_client.py new file mode 100644 index 0000000000..d93076b7ec --- /dev/null +++ b/demo/triton/text-detection/grpc_client.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import cv2 +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(img, results): + bboxes = results['bboxes'] + scores = results['scores'] + for (bbox, score) in zip(bboxes, scores): + x = list(map(int, bbox[::2])) + y = list(map(int, bbox[1::2])) + n = len(x) + for i in range(n): + p1 = (x[i], y[i]) + p2 = (x[(i + 1) % n], y[(i + 1) % n]) + img = cv2.line(img, p1, p2, (0, 255, 0), 1) + cv2.imwrite('text-detection.jpg', img) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(img, results) diff --git a/demo/triton/text-detection/serving/model/1/README.md b/demo/triton/text-detection/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/text-detection/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/text-detection/serving/model/config.pbtxt b/demo/triton/text-detection/serving/model/config.pbtxt new file mode 100644 index 0000000000..b8afc6aa76 --- /dev/null +++ b/demo/triton/text-detection/serving/model/config.pbtxt @@ -0,0 +1,21 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [-1, -1, 3] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 8 ] +} + +output { + name: "scores" + data_type: TYPE_FP32 + dims: [ -1, 1 ] +} diff --git a/demo/triton/text-ocr/README.md b/demo/triton/text-ocr/README.md new file mode 100644 index 0000000000..acda4e62e8 --- /dev/null +++ b/demo/triton/text-ocr/README.md @@ -0,0 +1,56 @@ +# Text ocr serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy + +# text-detection +python3 tools/deploy.py \ + configs/mmocr/text-detection/text-detection_tensorrt_dynamic-320x320-2240x2240.py \ + ../mmocr/configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py \ + https://download.openmmlab.com/mmocr/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth \ + ../mmocr/demo/demo_text_det.jpg \ + --work-dir work_dir/panet \ + --dump-info \ + --device cuda:0 + +# text-recognition +python3 tools/deploy.py \ + configs/mmocr/text-recognition/text-recognition_tensorrt-fp16_dynamic-1x32x32-1x32x640.py \ + ../mmocr/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py \ + https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth \ + ../mmocr/demo/demo_text_recog.jpg \ + --work-dir work_dir/crnn \ + --device cuda \ + --dump-info +``` + +## Ensemble detection and recognition model + +``` +cd /root/workspace/mmdeploy +cp -r demo/triton/text-ocr/serving /model-repository +cp -r work_dir/panet/* /model-repository/model/1/text_detection/ +cp -r work_dir/crnn/* /model-repository/model/1/text_recognition/ +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/text-ocr/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/text-ocr/grpc_client.py b/demo/triton/text-ocr/grpc_client.py new file mode 100644 index 0000000000..7cf70cad29 --- /dev/null +++ b/demo/triton/text-ocr/grpc_client.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import cv2 +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image): + """ + Args: + image: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [InferInput(self._input_names[0], image.shape, 'UINT8')] + inputs[0].set_data_from_numpy(image) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(results): + det_bboxes = results['bboxes'] + det_scores = results['scores'] + rec_texts = results['rec_texts'] + rec_scores = results['rec_scores'] + for i, (det_bbox, det_score, rec_text, rec_score) in \ + enumerate(zip(det_bboxes, det_scores, rec_texts, rec_scores)): + print(f'bbox[{i}] ({det_bbox[0]:.2f}, {det_bbox[1]:.2f}), ' + f'({det_bbox[2]:.2f}, {det_bbox[3]:.2f}), ({det_bbox[4]:.2f}, ' + f'{det_bbox[5]:.2f}), ({det_bbox[6]:.2f}, {det_bbox[7]:.2f}), ' + f'{det_score[0]:.2f}') + text = rec_text.decode('utf-8') + print(f'text[{i}] {text}') + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + results = client.infer(img) + visualize(results) diff --git a/demo/triton/text-ocr/serving/model/1/pipeline_template.json b/demo/triton/text-ocr/serving/model/1/pipeline_template.json new file mode 100644 index 0000000000..37ab975475 --- /dev/null +++ b/demo/triton/text-ocr/serving/model/1/pipeline_template.json @@ -0,0 +1,50 @@ +{ + "model_names": [ + "text_detection", + "text_recognition" + ], + "task_type": "TextOCR", + "type": "Pipeline", + "input": "img", + "output": [ + "dets", + "texts" + ], + "tasks": [ + { + "type": "Inference", + "input": "img", + "output": "dets", + "params": { + "model": "text_detection" + } + }, + { + "type": "Pipeline", + "input": [ + "bboxes=*dets", + "imgs=+img" + ], + "tasks": [ + { + "type": "Task", + "module": "WarpBbox", + "input": [ + "imgs", + "bboxes" + ], + "output": "patches" + }, + { + "type": "Inference", + "input": "patches", + "output": "texts", + "params": { + "model": "text_recognition" + } + } + ], + "output": "*texts" + } + ] +} diff --git a/demo/triton/text-ocr/serving/model/1/text_detection/README.md b/demo/triton/text-ocr/serving/model/1/text_detection/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/text-ocr/serving/model/1/text_recognition/README.md b/demo/triton/text-ocr/serving/model/1/text_recognition/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/demo/triton/text-ocr/serving/model/config.pbtxt b/demo/triton/text-ocr/serving/model/config.pbtxt new file mode 100644 index 0000000000..9864802ad8 --- /dev/null +++ b/demo/triton/text-ocr/serving/model/config.pbtxt @@ -0,0 +1,33 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +output { + name: "bboxes" + data_type: TYPE_FP32 + dims: [ -1, 8 ] +} + +output { + name: "scores" + data_type: TYPE_FP32 + dims: [ -1, 1 ] +} + +output { + name: "rec_texts" + data_type: TYPE_STRING + dims: [ -1, 1] +} + +output { + name: "rec_scores" + data_type: TYPE_STRING + dims: [ -1, 1 ] +} diff --git a/demo/triton/text-recognition/README.md b/demo/triton/text-recognition/README.md new file mode 100644 index 0000000000..38c4b72cad --- /dev/null +++ b/demo/triton/text-recognition/README.md @@ -0,0 +1,44 @@ +# Text recognition serving + +## Starting a docker container + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +## Convert pytorch model to tensorrt model + +``` +cd /root/workspace/mmdeploy +python3 tools/deploy.py \ + configs/mmocr/text-recognition/text-recognition_tensorrt-fp16_dynamic-1x32x32-1x32x640.py \ + ../mmocr/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py \ + https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth \ + ../mmocr/demo/demo_text_recog.jpg \ + --work-dir work_dir/crnn \ + --device cuda \ + --dump-info +``` + +## Convert tensorrt model to triton format + +``` +cd /root/workspace/mmdeploy +python3 demo/triton/to_triton_model.py \ + /root/workspace/mmdeploy/work_dir/crnn \ + /model-repository +``` + +## Start triton server + +``` +tritonserver --model-repository=/model-repository +``` + +## Run client code output container + +``` +python3 demo/triton/text-detection/grpc_client.py \ + model \ + /path/to/image +``` diff --git a/demo/triton/text-recognition/grpc_client.py b/demo/triton/text-recognition/grpc_client.py new file mode 100644 index 0000000000..8e5384b252 --- /dev/null +++ b/demo/triton/text-recognition/grpc_client.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json + +import cv2 +import numpy as np +from tritonclient.grpc import (InferenceServerClient, InferInput, + InferRequestedOutput) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model_name', type=str, help='model name') + parser.add_argument('image', type=str, help='image path') + return parser.parse_args() + + +class GRPCTritonClient: + + def __init__(self, url, model_name, model_version): + self._url = url + self._model_name = model_name + self._model_version = model_version + self._client = InferenceServerClient(self._url) + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata( + self._model_name, self._model_version) + print(f'[model config]:\n{model_config}') + print(f'[model metadata]:\n{model_metadata}') + self._inputs = {input.name: input for input in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = { + output.name: output + for output in model_metadata.outputs + } + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def infer(self, image, box): + """ + Args: + image: np.ndarray + box: np.ndarray + Returns: + results: dict, {name : numpy.array} + """ + + inputs = [ + InferInput(self._input_names[0], image.shape, 'UINT8'), + InferInput(self._input_names[1], box.shape, 'BYTES') + ] + inputs[0].set_data_from_numpy(image) + inputs[1].set_data_from_numpy(box) + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=inputs, + outputs=self._outputs_req) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + +def visualize(results): + texts = results['texts'] + scores = results['scores'] + for box_texts_, box_scores_ in zip(texts, scores): + box_texts = box_texts_.decode('utf-8') + box_scores = json.loads(box_scores_.decode('utf-8')) + print(box_texts, box_scores) + + +if __name__ == '__main__': + args = parse_args() + model_name = args.model_name + model_version = '1' + url = 'localhost:8001' + client = GRPCTritonClient(url, model_name, model_version) + img = cv2.imread(args.image) + bbox = { + 'type': + 'TextBbox', + 'value': [{ + 'bbox': [ + 0.0, 0.0, img.shape[1], 0, img.shape[1], img.shape[0], 0, + img.shape[0] + ], + }] + } + bbox = np.array([json.dumps(bbox).encode('utf-8')]) + results = client.infer(img, bbox) + visualize(results) diff --git a/demo/triton/text-recognition/serving/model/1/README.md b/demo/triton/text-recognition/serving/model/1/README.md new file mode 100644 index 0000000000..3b5ec0b47e --- /dev/null +++ b/demo/triton/text-recognition/serving/model/1/README.md @@ -0,0 +1 @@ +This directory holds the model files. diff --git a/demo/triton/text-recognition/serving/model/config.pbtxt b/demo/triton/text-recognition/serving/model/config.pbtxt new file mode 100644 index 0000000000..759911a09b --- /dev/null +++ b/demo/triton/text-recognition/serving/model/config.pbtxt @@ -0,0 +1,28 @@ +backend: "mmdeploy" + +input { + name: "ori_img" + data_type: TYPE_UINT8 + format: FORMAT_NHWC + dims: [ -1, -1, 3 ] + allow_ragged_batch: true +} + +input { + name: "TextBbox" + data_type: TYPE_STRING + dims: [ 1 ] + allow_ragged_batch: true +} + +output { + name: "texts" + data_type: TYPE_STRING + dims: [ -1, 1] +} + +output { + name: "scores" + data_type: TYPE_STRING + dims: [ -1, 1 ] +} diff --git a/demo/triton/to_triton_model.py b/demo/triton/to_triton_model.py new file mode 100644 index 0000000000..81cb864ad1 --- /dev/null +++ b/demo/triton/to_triton_model.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json +import os +import os.path as osp +import shutil +from enum import Enum, unique +from glob import glob + +BASEDIR = os.path.dirname(__file__) + + +@unique +class Template(str, Enum): + ImageClassification = 'image-classification/serving' + InstanceSegmentation = 'instance-segmentation/serving' + KeypointDetection = 'keypoint-detection/serving' + ObjectDetection = 'object-detection/serving' + OrientedObjectDetection = 'oriented-object-detection/serving' + SemanticSegmentation1 = 'semantic-segmentation/serving/mask' + SemanticSegmentation2 = 'semantic-segmentation/serving/score' + TextRecognition = 'text-recognition/serving' + TextDetection = 'text-detection/serving' + + +def copy_template(src_folder, dst_folder): + files = glob(osp.join(src_folder, '*')) + for src in files: + dst = osp.join(dst_folder, osp.basename(src)) + if osp.isdir(src): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.copy(src, dst) + + +class Convert: + + def __init__(self, model_type, model_dir, deploy_cfg, pipeline_cfg, + detail_cfg, output_dir): + self._model_type = model_type + self._model_dir = model_dir + self._deploy_cfg = deploy_cfg + self._pipeline_cfg = pipeline_cfg + self._detail_cfg = detail_cfg + self._output_dir = output_dir + + def copy_file(self, file_name, src_folder, dst_folder): + src_path = osp.join(src_folder, file_name) + dst_path = osp.join(dst_folder, file_name) + if osp.isdir(src_path): + shutil.copytree(src_path, dst_path) + else: + shutil.copy(src_path, dst_path) + + def write_json_file(self, data, file_name, dst_folder): + dst_path = osp.join(dst_folder, file_name) + with open(dst_path, 'w') as f: + json.dump(data, f, indent=4) + + def create_single_model(self): + output_model_folder = osp.join(self._output_dir, 'model', '1') + if (self._model_type == Template.TextRecognition): + self._pipeline_cfg['pipeline']['input'].append('bbox') + self._pipeline_cfg['pipeline']['tasks'][0]['input'] = ['patch'] + warpbbox = { + 'type': 'Task', + 'module': 'WarpBbox', + 'input': ['img', 'bbox'], + 'output': ['patch'] + } + self._pipeline_cfg['pipeline']['tasks'].insert(0, warpbbox) + self.write_json_file(self._pipeline_cfg, 'pipeline.json', + output_model_folder) + else: + self.copy_file('pipeline.json', self._model_dir, + output_model_folder) + + self.copy_file('deploy.json', self._model_dir, output_model_folder) + models = self._deploy_cfg['models'] + for model in models: + net = model['net'] + self.copy_file(net, self._model_dir, output_model_folder) + for custom in self._deploy_cfg['customs']: + self.copy_file(custom, self._model_dir, output_model_folder) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + 'model_dir', + type=str, + help='converted model dir with ' + '`--dump-info` flag when convert the model') + parser.add_argument('output_dir', type=str, help='output dir') + return parser.parse_args() + + +def get_model_type(detail_cfg, pipeline_cfg): + task = detail_cfg['codebase_config']['task'] + output_names = detail_cfg['onnx_config']['output_names'] + + if task == 'Classification': + return Template.ImageClassification + if task == 'ObjectDetection': + if 'masks' in output_names: + return Template.InstanceSegmentation + else: + return Template.ObjectDetection + if task == 'Segmentation': + with_argmax = pipeline_cfg['pipeline']['tasks'][-1]['params'].get( + 'with_argmax', True) + if with_argmax: + return Template.SemanticSegmentation1 + else: + return Template.SemanticSegmentation2 + if task == 'PoseDetection': + return Template.KeypointDetection + if task == 'RotatedDetection': + return Template.OrientedObjectDetection + if task == 'TextRecognition': + return Template.TextRecognition + if task == 'TextDetection': + return Template.TextDetection + + assert 0, f'doesn\'t support task {task} with output_names: {output_names}' + + +if __name__ == '__main__': + args = parse_args() + model_dir = args.model_dir + output_dir = args.output_dir + + # check + assert osp.isdir(model_dir), f'model dir {model_dir} doesn\'t exist' + info_files = ['deploy.json', 'pipeline.json', 'detail.json'] + for file in info_files: + path = osp.join(model_dir, file) + assert osp.exists(path), f'{path} doesn\'t exist in {model_dir}' + + with open(osp.join(model_dir, 'deploy.json')) as f: + deploy_cfg = json.load(f) + with open(osp.join(model_dir, 'pipeline.json')) as f: + pipeline_cfg = json.load(f) + with open(osp.join(model_dir, 'detail.json')) as f: + detail_cfg = json.load(f) + assert 'onnx_config' in detail_cfg, \ + 'currently, only support onnx as middle ir' + + # process + model_type = get_model_type(detail_cfg, pipeline_cfg) + convert = Convert(model_type, model_dir, deploy_cfg, pipeline_cfg, + detail_cfg, output_dir) + + src_folder = osp.join(BASEDIR, model_type.value) + + if not osp.exists(output_dir): + os.makedirs(output_dir) + copy_template(src_folder, output_dir) + convert.create_single_model() diff --git a/docker/triton/Dockerfile b/docker/triton/Dockerfile new file mode 100644 index 0000000000..a1e9a16adb --- /dev/null +++ b/docker/triton/Dockerfile @@ -0,0 +1,78 @@ +FROM nvcr.io/nvidia/tritonserver:22.12-pyt-python-py3 + +ARG CUDA=11.3 +ARG TORCH_VERSION=1.10.0 +ARG TORCHVISION_VERSION=0.11.0 +ARG ONNXRUNTIME_VERSION=1.8.1 +ARG PPLCV_VERSION=0.7.0 +ENV FORCE_CUDA="1" +ARG MMCV_VERSION=">=2.0.0rc2" +ARG MMENGINE_VERSION=">=0.3.0" + +WORKDIR /root/workspace + +RUN wget https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.sh &&\ + bash cmake-3.26.3-linux-x86_64.sh --skip-license --prefix=/usr + +RUN git clone --depth 1 --branch v${PPLCV_VERSION} https://github.com/openppl-public/ppl.cv.git &&\ + cd ppl.cv &&\ + ./build.sh cuda &&\ + mv cuda-build/install ./ &&\ + rm -rf cuda-build +ENV pplcv_DIR=/root/workspace/ppl.cv/install/lib/cmake/ppl + +RUN apt-get update &&\ + apt-get install -y libopencv-dev + +RUN wget https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}.tgz \ + && tar -zxvf onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}.tgz +ENV ONNXRUNTIME_DIR=/root/workspace/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION} +ENV LD_LIBRARY_PATH=/root/workspace/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}/lib:$LD_LIBRARY_PATH + +RUN python3 -m pip install -U pip &&\ + pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html &&\ + pip install openmim &&\ + mim install "mmcv"${MMCV_VERSION} onnxruntime-gpu==${ONNXRUNTIME_VERSION} mmengine${MMENGINE_VERSION} &&\ + ln /usr/bin/python3 /usr/bin/python + +COPY TensorRT-8.2.3.0 /root/workspace/tensorrt +RUN pip install /root/workspace/tensorrt/python/*cp38*whl +ENV TENSORRT_DIR=/root/workspace/tensorrt +ENV LD_LIBRARY_PATH=/root/workspace/tensorrt/lib:$LD_LIBRARY_PATH + +RUN apt-get install -y rapidjson-dev + +RUN git clone -b v1.0.0rc7 https://github.com/open-mmlab/mmpretrain.git &&\ + cd mmpretrain && pip install . + +RUN git clone -b v3.0.0 https://github.com/open-mmlab/mmdetection.git &&\ + cd mmdetection && pip install . + +RUN git clone -b v1.0.0 https://github.com/open-mmlab/mmsegmentation.git &&\ + cd mmsegmentation && pip install . + +RUN git clone -b v1.0.0 https://github.com/open-mmlab/mmocr.git &&\ + cd mmocr && pip install . + +RUN git clone -b v1.0.0rc1 https://github.com/open-mmlab/mmrotate.git &&\ + cd mmrotate && pip install . + +RUN git clone -b v1.0.0 https://github.com/open-mmlab/mmpose.git &&\ + cd mmpose && pip install . + +RUN git clone -b triton-server --recursive https://github.com/irexyc/mmdeploy &&\ + cd mmdeploy && mkdir -p build && cd build &&\ + cmake .. \ + -DMMDEPLOY_BUILD_SDK=ON \ + -DMMDEPLOY_TARGET_DEVICES="cuda;cpu" \ + -DMMDEPLOY_BUILD_TEST=OFF \ + -DMMDEPLOY_TARGET_BACKENDS="trt;ort" \ + -DMMDEPLOY_CODEBASES=all \ + -Dpplcv_DIR=${pplcv_DIR} \ + -DMMDEPLOY_BUILD_EXAMPLES=OFF \ + -DMMDEPLOY_DYNAMIC_BACKEND=OFF \ + -DTRITON_MMDEPLOY_BACKEND=ON \ + -DTRITON_TAG="r22.12" &&\ + make -j$(nproc) && make install &&\ + cp -r install/backends /opt/tritonserver/ &&\ + cd .. && pip install -e . --user diff --git a/docs/en/02-how-to-run/triton_server.md b/docs/en/02-how-to-run/triton_server.md new file mode 100644 index 0000000000..a8fe4b1df1 --- /dev/null +++ b/docs/en/02-how-to-run/triton_server.md @@ -0,0 +1,42 @@ +# Model serving + +MMDeploy provides model server deployment based on Triton Inference Server. + +## Supported tasks + +The following tasks are currently supported: + +- [image-classification](../../../demo/triton/image-classification/README.md) +- [instance-segmentation](../../../demo/triton/instance-segmentation) +- [keypoint-detection](../../../demo/triton/keypoint-detection) +- [object-detection](../../../demo/triton/object-detection) +- [oriented-object-detection](../../../demo/triton/oriented-object-detection) +- [semantic-segmentation](../../../demo/triton/semantic-segmentation) +- [text-detection](../../../demo/triton/text-detection) +- [text-recognition](../../../demo/triton/text-recognition) +- [text-ocr](../../../demo/triton/text-ocr) + +## Run Triton + +In order to use Triton Inference Server, we need: + +1. Compile MMDeploy Triton Backend +2. Prepare the model repository (including model files, and configuration files) + +### Compile MMDeploy Triton Backend + +a) Using Docker images + +For ease of use, we provide a Docker image to support the deployment of models converted by MMDeploy. The image supports Tensorrt and ONNX Runtime as backends. If you need other backends, you can choose build from source. + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +b) Build from source + +You can refer [build from source](../01-how-to-build/build_from_source.md) to build MMDeploy. In order to build MMDeploy Triton Backend, you need to add `-DTRITON_MMDEPLOY_BACKEND=ON` to cmake configure command. By default, the latest version of Triton Backend is used. If you want to use an older version of Triton Backend, you can add `-DTRITON_TAG=r22.12` to the cmake configure command. + +### Prepare the model repository + +Triton Inference Server has its own model description rules. Therefore the models converted through `tools/deploy.py ... --dump-info` need to be formatted to make Triton load correctly. We have prepared templates for each task. You can use `demo/triton/to_triton_model.py` script for model formatting. For complete samples, please refer to the description of each demo. diff --git a/docs/en/get_started.md b/docs/en/get_started.md index 9fce872ea1..f3afdb1080 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -330,6 +330,10 @@ We'll talk about them more in our next release. If you want to fuse preprocess for acceleration,please refer to this [doc](./02-how-to-run/fuse_transform.md) +## Model serving (triton) + +For server-side deployment, please read [model serving](02-how-to-run/triton_server.md) for more details. + ## Evaluate Model You can test the performance of deployed model using `tool/test.py`. For example, diff --git a/docs/en/index.rst b/docs/en/index.rst index 0704aeaf8f..1833ef1d0c 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -27,6 +27,7 @@ You can switch between Chinese and English documents in the lower-left corner of 02-how-to-run/profile_model.md 02-how-to-run/quantize_model.md 02-how-to-run/useful_tools.md + 02-how-to-run/triton_server.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/02-how-to-run/triton_server.md b/docs/zh_cn/02-how-to-run/triton_server.md new file mode 100644 index 0000000000..bc25d6da03 --- /dev/null +++ b/docs/zh_cn/02-how-to-run/triton_server.md @@ -0,0 +1,42 @@ +# 如何进行服务端部署 + +模型转换后,MMDeploy 提供基于 Triton Inference Server 的模型服务端部署。 + +## 支持的任务 + +目前支持以下任务: + +- [image-classification](../../../demo/triton/image-classification/README.md) +- [instance-segmentation](../../../demo/triton/instance-segmentation) +- [keypoint-detection](../../../demo/triton/keypoint-detection) +- [object-detection](../../../demo/triton/object-detection) +- [oriented-object-detection](../../../demo/triton/oriented-object-detection) +- [semantic-segmentation](../../../demo/triton/semantic-segmentation) +- [text-detection](../../../demo/triton/text-detection) +- [text-recognition](../../../demo/triton/text-recognition) +- [text-ocr](../../../demo/triton/text-ocr) + +## 如何部署 Triton 服务 + +为了使用 Triton Inference Server, 我们需要: + +1. 编译 MMDeploy Triton Backend +2. 准备模型库(包括模型文件,以及配置文件) + +### 编译 MMDeploy Triton Backend + +a) 使用 Docker 镜像 + +为了方便使用,我们提供了 Docker 镜像,支持对通过 MMDeploy 转换的模型进行部署。镜像支持 Tensorrt 以及 ONNX Runtime 作为后端。若需要其他后端,可选择从源码进行编译。 + +``` +docker run -it --rm --gpus all openmmlab/mmdeploy:triton-22.12 +``` + +b) 从源码编译 + +从源码编译 MMDeploy 的方式可参考[源码手动安装](../01-how-to-build/build_from_source.md),要编译 MMDeploy Triton Backend,需要在编译命令中添加:`-DTRITON_MMDEPLOY_BACKEND=ON`。默认使用最新版本的 Triton Backend,若要使用旧版本的 Triton Backend,可在编译命令中添加`-DTRITON_TAG=r22.12` + +### 准备模型库 + +Triton Inference Server 有一套自己的模型描述规则,通过 `tools/deploy.py ... --dump-info ` 转换的模型需要调整格式才能使 Triton 正确加载,我们为各任务准备了模版,可以运行 `demo/triton/to_triton_model.py` 转换脚本格式进行修改。完整的样例可参考各个 demo 的说明。 diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index 27b4e55245..ed85e52d25 100644 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -331,6 +331,10 @@ target_link_libraries(${name} PRIVATE mmdeploy ${OpenCV_LIBS}) 若要对预处理进行加速,请查阅[此处](./02-how-to-run/fuse_transform.md) +## 服务端部署 (triton) + +若需要进行服务端部署,请阅读 [服务端部署](02-how-to-run/triton_server.md) 了解更多细节 + ## 模型精度评估 为了测试部署模型的精度,推理效率,我们提供了 `tools/test.py` 来帮助完成相关工作。以上文中的部署模型为例: diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index e52a40c7aa..12959be2aa 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -27,6 +27,7 @@ 02-how-to-run/profile_model.md 02-how-to-run/quantize_model.md 02-how-to-run/useful_tools.md + 02-how-to-run/triton_server.md .. toctree:: :maxdepth: 1