diff --git a/cpp/build.sh b/cpp/build.sh
index 23df4df722..79fc385e51 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -25,7 +25,6 @@ function install_dependencies_linux() {
     autoconf \
     automake \
     git \
-    cmake \
     m4 \
     g++ \
     flex \
@@ -175,6 +174,14 @@ function install_libtorch() {
       wget https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.1%2Bcu116.zip
       unzip libtorch-cxx11-abi-shared-with-deps-1.12.1+cu116.zip
       rm libtorch-cxx11-abi-shared-with-deps-1.12.1+cu116.zip
+    elif [ "$CUDA" = "cu117" ]; then
+      wget https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu117.zip
+      unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117.zip
+      rm libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117.zip
+    elif [ "$CUDA" = "cu118" ]; then
+      wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu118.zip
+      unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cu118.zip
+      rm libtorch-cxx11-abi-shared-with-deps-2.0.1+cu118.zip
     else
       wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.12.1%2Bcpu.zip
       unzip libtorch-cxx11-abi-shared-with-deps-1.12.1+cpu.zip
@@ -254,7 +261,7 @@ function build() {
   find $FOLLY_CMAKE_DIR -name "lib*.*" -exec ln -s "{}" $LIBS_DIR/ \;
   if [ "$PLATFORM" = "Linux" ]; then
     cmake \
-    -DCMAKE_PREFIX_PATH="$DEPS_DIR;$FOLLY_CMAKE_DIR;$YAML_CPP_CMAKE_DIR;$DEPS_DIR/libtorch" \
+    -DCMAKE_PREFIX_PATH="$DEPS_DIR;$FOLLY_CMAKE_DIR;$YAML_CPP_CMAKE_DIR;$DEPS_DIR/libtorch;" \
     -DCMAKE_INSTALL_PREFIX="$PREFIX" \
     "$MAYBE_BUILD_QUIC" \
     "$MAYBE_BUILD_TESTS" \
@@ -265,7 +272,7 @@ function build() {
     "$MAYBE_CUDA_COMPILER" \
     ..
 
-    if [ "$CUDA" = "cu102" ] || [ "$CUDA" = "cu113" ] || [ "$CUDA" = "cu116" ]; then
+    if [ "$CUDA" = "cu102" ] || [ "$CUDA" = "cu113" ] || [ "$CUDA" = "cu116" ] || [ "$CUDA" = "cu117" ] || [ "$CUDA" = "cu118" ]; then
       export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/bin/nvcc
     fi
   elif [ "$PLATFORM" = "Mac" ]; then
@@ -299,6 +306,11 @@ function build() {
     mv $DEPS_DIR/../src/examples/libmnist_handler.so $DEPS_DIR/../../test/resources/torchscript_model/mnist/mnist_handler/libmnist_handler.so
   fi
 
+  # Place the built ResNet-18 handler library in its test model directory.
+  if [ -f "$DEPS_DIR/../src/examples/libresnet-18_handler.so" ]; then
+    mv "$DEPS_DIR/../src/examples/libresnet-18_handler.so" "$DEPS_DIR/../../test/resources/torchscript_model/resnet-18/resnet-18_handler/libresnet-18_handler.so"
+  fi
+
   cd $DEPS_DIR/../..
   if [ -f "$DEPS_DIR/../test/torchserve_cpp_test" ]; then
     $DEPS_DIR/../test/torchserve_cpp_test
@@ -329,7 +341,7 @@ INSTALL_DEPENDENCIES=false
 PREFIX=""
 COMPILER_FLAGS=""
 CUDA=""
-USAGE="./build.sh [-j num_jobs] [-g cu102|cu113|cu116] [-q|--with-quic] [--install-dependencies] [-p|--prefix] [-x|--compiler-flags]"
+USAGE="./build.sh [-j num_jobs] [-g cu102|cu113|cu116|cu117|cu118] [-q|--with-quic] [--install-dependencies] [-p|--prefix] [-x|--compiler-flags]"
 while [ "$1" != "" ]; do
   case $1 in
     -j | --jobs ) shift
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index 4c9c534097..b9e778fd1c 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -1,7 +1,3 @@
-set(MNIST_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/image_classifier/mnist")
+find_package(OpenCV REQUIRED)
 
-set(MNIST_SOURCE_FILES "")
-list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc)
-add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES})
-target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR})
-target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
+add_subdirectory(image_classifier)
diff --git 
a/cpp/src/examples/image_classifier/CMakeLists.txt b/cpp/src/examples/image_classifier/CMakeLists.txt
new file mode 100644
index 0000000000..a5a5456040
--- /dev/null
+++ b/cpp/src/examples/image_classifier/CMakeLists.txt
@@ -0,0 +1,20 @@
+set(MNIST_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/image_classifier/mnist")
+
+set(MNIST_SOURCE_FILES "")
+list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc)
+add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES})
+target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR})
+target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
+
+set(RESNET_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/image_classifier/resnet-18")
+
+set(RESNET_SOURCE_FILES "")
+
+list(APPEND RESNET_SOURCE_FILES ${RESNET_SRC_DIR}/resnet-18_handler.cc)
+add_library(resnet-18_handler SHARED ${RESNET_SOURCE_FILES})
+target_include_directories(resnet-18_handler PUBLIC ${RESNET_SRC_DIR})
+# OpenCV is used only inside resnet-18_handler.cc, so keep its headers private
+# to this target instead of adding them to every target in the directory.
+target_include_directories(resnet-18_handler PRIVATE ${OpenCV_INCLUDE_DIRS})
+target_link_libraries(resnet-18_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
+target_link_libraries(resnet-18_handler PRIVATE ${OpenCV_LIBS})
diff --git a/cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc b/cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc
new file mode 100644
index 0000000000..06b71098b2
--- /dev/null
+++ b/cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc
@@ -0,0 +1,262 @@
+#include "src/examples/image_classifier/resnet-18/resnet-18_handler.hh"
+
+// NOTE(review): the header names on the #include lines below were lost in
+// the patch text. They are reconstructed from the symbols this file uses
+// (folly::dynamic, cv::imdecode, cv::cuda::resize, cv::cuda::cvtColor);
+// confirm against the build.
+#include <folly/json.h>
+
+#include <opencv2/cudaimgproc.hpp>
+#include <opencv2/cudawarping.hpp>
+#include <opencv2/imgcodecs.hpp>
+#include <opencv2/opencv.hpp>
+
+namespace resnet {
+
+// Side length, in pixels, of the square image passed to the model.
+constexpr int kTargetImageSize = 224;
+// Per-channel (R, G, B) mean and standard deviation used for normalization.
+constexpr double kImageNormalizationMeanR = 0.485;
+constexpr double kImageNormalizationMeanG = 0.456;
+constexpr double kImageNormalizationMeanB = 0.406;
+constexpr double kImageNormalizationStdR = 0.229;
+constexpr double kImageNormalizationStdG = 0.224;
+constexpr double kImageNormalizationStdB = 0.225;
+// Number of highest-probability classes reported for each request.
+constexpr int kTopKClasses = 5;
+
+// Turns every request's encoded image into a normalized CHW float tensor and
+// returns the stacked batch (moved to `device`) as the model input. Requests
+// that cannot be converted get a 500 response and are not recorded in
+// idx_to_req_id.second, so Postprocess skips them.
+std::vector<torch::jit::IValue> ResnetHandler::Preprocess(
+    std::shared_ptr<torch::Device>& device,
+    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
+    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
+  std::vector<torch::jit::IValue> batch_ivalue;
+  std::vector<torch::Tensor> batch_tensors;
+  uint8_t idx = 0;
+  for (auto& request : *request_batch) {
+    (*response_batch)[request.request_id] =
+        std::make_shared<torchserve::InferenceResponse>(request.request_id);
+    idx_to_req_id.first += idx_to_req_id.first.empty()
+                               ? request.request_id
+                               : "," + request.request_id;
+    auto data_it =
+        request.parameters.find(torchserve::PayloadType::kPARAMETER_NAME_DATA);
+    auto dtype_it =
+        request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
+    if (data_it == request.parameters.end()) {
+      data_it = request.parameters.find(
+          torchserve::PayloadType::kPARAMETER_NAME_BODY);
+      dtype_it =
+          request.headers.find(torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
+    }
+
+    if (data_it == request.parameters.end() ||
+        dtype_it == request.headers.end()) {
+      TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id);
+      (*response_batch)[request.request_id]->SetResponse(
+          500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
+          "Empty payload");
+      continue;
+    }
+
+    try {
+      if (dtype_it->second == torchserve::PayloadType::kDATA_TYPE_BYTES) {
+        cv::Mat image = cv::imdecode(data_it->second, cv::IMREAD_COLOR);
+
+        // imdecode reports failure by returning an empty Mat. Answer the
+        // request with an error instead of pushing an empty image through
+        // the rest of the pipeline.
+        if (image.empty()) {
+          TS_LOGF(ERROR, "Failed to decode the image for request id: {}",
+                  request.request_id);
+          (*response_batch)[request.request_id]->SetResponse(
+              500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
+              "Failed to decode the image");
+          continue;
+        }
+
+        // Center-crop to the largest square that fits inside the image.
+        const int rows = image.rows;
+        const int cols = image.cols;
+
+        const int cropSize = std::min(rows, cols);
+        const int offsetW = (cols - cropSize) / 2;
+        const int offsetH = (rows - cropSize) / 2;
+
+        const cv::Rect roi(offsetW, offsetH, cropSize, cropSize);
+        image = image(roi);
+
+        // Convert the image to GPU Mat
+        cv::cuda::GpuMat gpuImage;
+        cv::Mat resultImage;
+
+        gpuImage.upload(image);
+
+        // Resize on GPU
+        cv::cuda::resize(gpuImage, gpuImage,
+                         cv::Size(kTargetImageSize, kTargetImageSize));
+
+        // Reorder channels from BGR (the order imdecode produces for
+        // IMREAD_COLOR) to RGB on GPU
+        cv::cuda::cvtColor(gpuImage, gpuImage, cv::COLOR_BGR2RGB);
+
+        // Convert to float on GPU, scaling each value by 1/255
+        gpuImage.convertTo(gpuImage, CV_32FC3, 1 / 255.0);
+
+        // Download the final image from GPU to CPU
+        gpuImage.download(resultImage);
+
+        // Wrap the CPU Mat's buffer (height x width x 3) without copying,
+        // then view it channel-first.
+        torch::Tensor tensorImage = torch::from_blob(
+            resultImage.data, {resultImage.rows, resultImage.cols, 3},
+            torch::kFloat);
+        tensorImage = tensorImage.permute({2, 0, 1});
+
+        std::vector<double> norm_mean = {kImageNormalizationMeanR,
+                                         kImageNormalizationMeanG,
+                                         kImageNormalizationMeanB};
+        std::vector<double> norm_std = {kImageNormalizationStdR,
+                                        kImageNormalizationStdG,
+                                        kImageNormalizationStdB};
+
+        // Normalize the tensor
+        tensorImage = torch::data::transforms::Normalize<>(
+            norm_mean, norm_std)(tensorImage);
+
+        batch_tensors.emplace_back(tensorImage.to(*device));
+        idx_to_req_id.second[idx++] = request.request_id;
+      } else {
+        // Only raw image bytes are handled; other payload types (e.g. "List")
+        // are not implemented, so say so instead of dropping the request.
+        TS_LOGF(ERROR, "Unsupported data type {} for request id: {}",
+                dtype_it->second, request.request_id);
+        (*response_batch)[request.request_id]->SetResponse(
+            500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
+            "Unsupported data type");
+      }
+    } catch (const std::runtime_error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              request.request_id, e.what());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to load tensor");
+    } catch (const cv::Exception& e) {
+      // OpenCV failures (e.g. a build without CUDA support) are not
+      // std::runtime_error, so they need their own handler.
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, cv error: {}",
+              request.request_id, e.what());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "cv error, failed to load tensor");
+    } catch (const c10::Error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}",
+              request.request_id, e.msg());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to load tensor");
+    }
+  }
+  if (!batch_tensors.empty()) {
+    batch_ivalue.emplace_back(torch::stack(batch_tensors).to(*device));
+  }
+
+  return batch_ivalue;
+}
+
+// Writes each request's top-kTopKClasses class indices and softmax
+// probabilities into its response as a JSON object with "probability" and
+// "classes" arrays. Errors are recorded on the response and rethrown.
+void ResnetHandler::Postprocess(
+    const torch::Tensor& data,
+    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
+  for (const auto& kv : idx_to_req_id.second) {
+    try {
+      auto response = (*response_batch)[kv.second];
+      namespace F = torch::nn::functional;
+
+      // kv.first is the index Preprocess assigned to this request when it
+      // added the request's tensor to the batch. Slice out only that row
+      // (keeping the batch dimension) so a request never receives the
+      // results of the other requests in the batch.
+      torch::Tensor request_output = data.narrow(0, kv.first, 1);
+
+      // Perform softmax and top-k operations
+      torch::Tensor ps = F::softmax(request_output, F::SoftmaxFuncOptions(1));
+      std::tuple<torch::Tensor, torch::Tensor> result =
+          torch::topk(ps, kTopKClasses, 1, true, true);
+      torch::Tensor probs = std::get<0>(result).to(torch::kCPU).contiguous();
+      torch::Tensor classes = std::get<1>(result).to(torch::kCPU).contiguous();
+
+      // Create a JSON object using folly::dynamic
+      folly::dynamic json_response = folly::dynamic::object;
+      // Create a folly::dynamic array to hold tensor elements
+      folly::dynamic probability = folly::dynamic::array;
+      folly::dynamic class_names = folly::dynamic::array;
+
+      // Copy the tensor elements into the arrays. topk returns the indices
+      // as 64-bit integers, so read them as int64_t rather than long.
+      const float* prob_values = probs.data_ptr<float>();
+      for (int64_t i = 0; i < probs.numel(); ++i) {
+        probability.push_back(prob_values[i]);
+      }
+      const int64_t* class_values = classes.data_ptr<int64_t>();
+      for (int64_t i = 0; i < classes.numel(); ++i) {
+        class_names.push_back(class_values[i]);
+      }
+      // Add key-value pairs to the JSON object
+      json_response["probability"] = probability;
+      json_response["classes"] = class_names;
+
+      // Serialize the JSON object to a string
+      std::string json_str = folly::toJson(json_response);
+
+      // Serialize and set the response
+      response->SetResponse(200, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_BYTES,
+                            json_str);
+    } catch (const std::runtime_error& e) {
+      LOG(ERROR) << "Failed to load tensor for request id:" << kv.second
+                 << ", error: " << e.what();
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to load tensor");
+      // Rethrow the exception in flight; `throw e;` would throw a sliced copy.
+      throw;
+    } catch (const c10::Error& e) {
+      LOG(ERROR) << "Failed to load tensor for request id:" << kv.second
+                 << ", c10 error: " << e.msg();
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to load tensor");
+      // Rethrow the exception in flight rather than a copy of it.
+      throw;
+    }
+  }
+}
+
+}  // namespace resnet
+
+#if defined(__linux__) || defined(__APPLE__)
+extern "C" {
+torchserve::torchscripted::BaseHandler* allocatorResnetHandler() {
+  return new resnet::ResnetHandler();
+}
+
+void deleterResnetHandler(torchserve::torchscripted::BaseHandler* p) {
+  if (p != nullptr) {
+    delete static_cast<resnet::ResnetHandler*>(p);
+  }
+}
+}
+#endif
diff --git a/cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.hh b/cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.hh
new file mode 100644
index 0000000000..d371d61a46
--- /dev/null
+++ b/cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.hh
@@ -0,0 +1,33 @@
+#ifndef RESNET_HANDLER_HH_
+#define RESNET_HANDLER_HH_
+
+#include "src/backends/torch_scripted/handler/base_handler.hh"
+
+namespace resnet {
+// Image-classification handler for a TorchScript ResNet-18 model.
+class ResnetHandler : public torchserve::torchscripted::BaseHandler {
+ public:
+  // NOLINTBEGIN(bugprone-exception-escape)
+  ResnetHandler() = default;
+  // NOLINTEND(bugprone-exception-escape)
+  ~ResnetHandler() override = default;
+
+  // Decodes and normalizes the request images and returns them as one
+  // batched model input.
+  std::vector<torch::jit::IValue> Preprocess(
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  // Writes the top classes and their probabilities for each request into
+  // its response as JSON.
+  void Postprocess(
+      const torch::Tensor& data,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+};
+}  // namespace resnet
+#endif  // RESNET_HANDLER_HH_
diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
index b3099d1a2a..13dde46da2 100644
--- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
+++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
@@ -78,6 +78,16 @@ 
TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandler) {
       "mnist_ts", 200);
 }
 
+TEST_F(TorchScriptedBackendTest, TestLoadPredictResnetHandler) {
+  this->LoadPredict(
+      std::make_shared<torchserve::LoadModelRequest>(
+          "test/resources/torchscript_model/resnet-18/resnet-18_handler",
+          "resnet-18", -1, "", "", 1, false),
+      "test/resources/torchscript_model/resnet-18/resnet-18_handler",
+      "test/resources/torchscript_model/resnet-18/kitten.jpg", "resnet-18_ts",
+      200);
+}
+
 TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) {
   auto result = backend_->Initialize("test/resources/torchscript_model/mnist");
   ASSERT_EQ(result, false);
diff --git a/cpp/test/resources/torchscript_model/resnet-18/kitten.jpg b/cpp/test/resources/torchscript_model/resnet-18/kitten.jpg
new file mode 100644
index 0000000000..ffcd2be2c6
Binary files /dev/null and b/cpp/test/resources/torchscript_model/resnet-18/kitten.jpg differ
diff --git a/cpp/test/resources/torchscript_model/resnet-18/resnet-18_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/torchscript_model/resnet-18/resnet-18_handler/MAR-INF/MANIFEST.json
new file mode 100644
index 0000000000..84bfacdcb6
--- /dev/null
+++ b/cpp/test/resources/torchscript_model/resnet-18/resnet-18_handler/MAR-INF/MANIFEST.json
@@ -0,0 +1,11 @@
+{
+  "createdOn": "28/07/2020 06:32:08",
+  "runtime": "LSP",
+  "model": {
+    "modelName": "resnet-18",
+    "serializedFile": "resnet-18.pt",
+    "handler": "libresnet-18_handler:ResnetHandler",
+    "modelVersion": "2.0"
+  },
+  "archiverVersion": "0.2.0"
+}
\ No newline at end of file