diff --git a/.gitmodules b/.gitmodules index a2584b165a..3125a3b997 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "cpp/third-party/llama2.c"] path = cpp/third-party/llama2.c url = https://github.com/karpathy/llama2.c +[submodule "cpp/third-party/llama2.so"] + path = cpp/third-party/llama2.so + url = https://github.com/mreso/llama2.so.git diff --git a/cpp/README.md b/cpp/README.md index 4f7dd53318..70b96339b9 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -42,7 +42,7 @@ By default, TorchServe cpp provides a handler for TorchScript [src/backends/hand * [Preprocess](serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L40) * [Inference](serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L46) * [Postprocess](serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L53) -#### Example +#### Usage ##### Using TorchScriptHandler * set runtime as "LSP" in model archiver option [--runtime](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) * set handler as "TorchScriptHandler" in model archiver option [--handler](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) @@ -58,49 +58,12 @@ Here is an [example](https://github.com/pytorch/serve/tree/cpp_backend/cpp/test/ torch-model-archiver --model-name mnist_handler --version 1.0 --serialized-file mnist_script.pt --handler libmnist_handler:MnistHandler --runtime LSP ``` Here is an [example](https://github.com/pytorch/serve/tree/cpp_backend/cpp/test/resources/examples/mnist/mnist_handler) of unzipped model mar file. -##### BabyLLama Example -The babyllama example can be found [here](https://github.com/pytorch/serve/blob/master/cpp/src/examples/babyllama/). -To run the example we need to download the weights as well as tokenizer files: -```bash -wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin -wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin -``` -Subsequently, we need to adjust the paths according to our local file structure in [config.json](https://github.com/pytorch/serve/blob/master/serve/cpp/test/resources/examples/babyllama/babyllama_handler/config.json). -```bash -{ -"checkpoint_path" : "/home/ubuntu/serve/cpp/stories15M.bin", -"tokenizer_path" : "/home/ubuntu/serve/cpp/src/examples/babyllama/tokenizer.bin" -} -``` -Then we can create the mar file and deploy it with: -```bash -cd serve/cpp/test/resources/examples/babyllama/babyllama_handler -torch-model-archiver --model-name llm --version 1.0 --handler libbabyllama_handler:BabyLlamaHandler --runtime LSP --extra-files config.json -mkdir model_store && mv llm.mar model_store/ -torchserve --ncs --start --model-store model_store - -curl -v -X POST "http://localhost:8081/models?initial_workers=1&url=llm.mar" -``` -The handler name `libbabyllama_handler:BabyLlamaHandler` consists of our shared library name (as defined in our [CMakeLists.txt](https://github.com/pytorch/serve/blob/master/serve/cpp/src/examples/CMakeLists.txt)) as well as the class name we chose for our [custom handler class](https://github.com/pytorch/serve/blob/master/serve/cpp/src/examples/babyllama/baby_llama_handler.cc) which derives its properties from BaseHandler. -To test the model we can run: -```bash -cd serve/cpp/test/resources/examples/babyllama/ -curl http://localhost:8080/predictions/llm -T prompt.txt -``` -##### Mnist example -* Transform data on client side. 
For example: -``` -import torch -from PIL import Image -from torchvision import transforms - -image_processing = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) -image = Image.open("examples/image_classifier/mnist/test_data/0.png") -image = image_processing(image) -torch.save(image, "0_png.pt") -``` -* Run model registration and prediction: [Using BaseHandler](serve/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc#L54) or [Using customized handler](serve/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc#L72). +#### Examples +We have created a couple of examples that can get you started with the C++ backend. +The examples are all located under serve/examples/cpp and each comes with a detailed description of how to set it up. +The following examples are available: +* [AOTInductor Llama](../examples/cpp/aot_inductor/llama2/) +* [BabyLlama](../examples/cpp/babyllama/) +* [Llama.cpp](../examples/cpp/llamacpp/) +* [MNIST](../examples/cpp/mnist/) diff --git a/cpp/build.sh b/cpp/build.sh index e9f4a4d3d2..6f6dbf81e9 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -74,40 +74,49 @@ function install_kineto() { } function install_libtorch() { - if [ ! -d "$DEPS_DIR/libtorch" ] ; then - cd "$DEPS_DIR" || exit - if [ "$PLATFORM" = "Mac" ]; then + TORCH_VERSION="2.2.0" + if [ "$PLATFORM" = "Mac" ]; then + if [ ! -d "$DEPS_DIR/libtorch" ]; then if [[ $(uname -m) == 'x86_64' ]]; then echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac x86_64 ${COLOR_OFF}" - wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip - unzip libtorch-macos-x86_64-2.2.0.zip - rm libtorch-macos-x86_64-2.2.0.zip + wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-${TORCH_VERSION}.zip + unzip libtorch-macos-x86_64-${TORCH_VERSION}.zip + rm libtorch-macos-x86_64-${TORCH_VERSION}.zip else echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac arm64 ${COLOR_OFF}" - wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.2.0.zip - unzip libtorch-macos-arm64-2.2.0.zip - rm libtorch-macos-arm64-2.2.0.zip + wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-${TORCH_VERSION}.zip + unzip libtorch-macos-arm64-${TORCH_VERSION}.zip + rm libtorch-macos-arm64-${TORCH_VERSION}.zip fi - elif [ "$PLATFORM" = "Linux" ]; then - echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Linux ${COLOR_OFF}" - if [ "$CUDA" = "cu118" ]; then - wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip - unzip libtorch-cxx11-abi-shared-with-deps-2.1.1+cu118.zip - rm libtorch-cxx11-abi-shared-with-deps-2.1.1+cu118.zip - elif [ "$CUDA" = "cu121" ]; then - wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu121.zip - unzip libtorch-cxx11-abi-shared-with-deps-2.1.1+cu121.zip - rm libtorch-cxx11-abi-shared-with-deps-2.1.1+cu121.zip - else - wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcpu.zip - unzip libtorch-cxx11-abi-shared-with-deps-2.1.1+cpu.zip - rm libtorch-cxx11-abi-shared-with-deps-2.1.1+cpu.zip - fi - elif [ "$PLATFORM" = "Windows" ]; then + fi + elif [ "$PLATFORM" = "Windows" ]; then echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Windows ${COLOR_OFF}" # TODO: Windows echo -e "${COLOR_RED}[ ERROR ] Unknown platform: $PLATFORM ${COLOR_OFF}" exit 1 + else # Linux + if [ -d "$DEPS_DIR/libtorch" ]; then + RAW_VERSION=`cat "$DEPS_DIR/libtorch/build-version"` + VERSION=`cat 
"$DEPS_DIR/libtorch/build-version" | cut -d "+" -f 1` + if [ "$USE_NIGHTLIES" = "true" ] && [[ ! "${RAW_VERSION}" =~ .*"dev".* ]]; then + rm -rf "$DEPS_DIR/libtorch" + elif [ "$USE_NIGHTLIES" == "" ] && [ "$VERSION" != "$TORCH_VERSION" ]; then + rm -rf "$DEPS_DIR/libtorch" + fi + fi + if [ ! -d "$DEPS_DIR/libtorch" ]; then + cd "$DEPS_DIR" || exit + echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Linux ${COLOR_OFF}" + if [ "$USE_NIGHTLIES" == true ]; then + URL=https://download.pytorch.org/libtorch/nightly/${CUDA}/libtorch-cxx11-abi-shared-with-deps-latest.zip + else + URL=https://download.pytorch.org/libtorch/${CUDA}/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2B${CUDA}.zip + fi + wget $URL + ZIP_FILE=$(basename "$URL") + ZIP_FILE="${ZIP_FILE//%2B/+}" + unzip $ZIP_FILE + rm $ZIP_FILE fi echo -e "${COLOR_GREEN}[ INFO ] libtorch is installed ${COLOR_OFF}" fi @@ -181,10 +190,34 @@ function build_llama_cpp() { BWD=$(pwd) LLAMA_CPP_SRC_DIR=$BASE_DIR/third-party/llama.cpp cd "${LLAMA_CPP_SRC_DIR}" - make LLAMA_METAL=OFF + if [ "$PLATFORM" = "Mac" ]; then + make LLAMA_METAL=OFF -j + else + make -j + fi cd "$BWD" || exit } +function prepare_test_files() { + echo -e "${COLOR_GREEN}[ INFO ]Preparing test files ${COLOR_OFF}" + local EX_DIR="${TR_DIR}/examples/" + rsync -a --link-dest=../../test/resources/ ${BASE_DIR}/test/resources/ ${TR_DIR}/ + if [ ! -f "${EX_DIR}/babyllama/babyllama_handler/tokenizer.bin" ]; then + wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin -O "${EX_DIR}/babyllama/babyllama_handler/tokenizer.bin" + fi + if [ ! -f "${EX_DIR}/babyllama/babyllama_handler/stories15M.bin" ]; then + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin -O "${EX_DIR}/babyllama/babyllama_handler/stories15M.bin" + fi + if [ ! -f "${EX_DIR}/aot_inductor/llama_handler/stories15M.so" ]; then + local HANDLER_DIR=${EX_DIR}/aot_inductor/llama_handler/ + if [ ! -f "${HANDLER_DIR}/stories15M.pt" ]; then + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt?download=true -O "${HANDLER_DIR}/stories15M.pt" + fi + local LLAMA_SO_DIR=${BASE_DIR}/third-party/llama2.so/ + PYTHONPATH=${LLAMA_SO_DIR}:${PYTHONPATH} python ${BASE_DIR}/../examples/cpp/aot_inductor/llama2/compile.py --checkpoint ${HANDLER_DIR}/stories15M.pt ${HANDLER_DIR}/stories15M.so + fi +} + function build() { MAYBE_BUILD_QUIC="" if [ "$WITH_QUIC" == true ] ; then @@ -209,6 +242,11 @@ function build() { MAYBE_CUDA_COMPILER='-DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc' fi + MAYBE_NIGHTLIES="-Dnightlies=OFF" + if [ "$USE_NIGHTLIES" == true ]; then + MAYBE_NIGHTLIES="-Dnightlies=ON" + fi + # Build torchserve_cpp with cmake cd "$BWD" || exit YAML_CPP_CMAKE_DIR=$DEPS_DIR/yaml-cpp-build @@ -225,6 +263,7 @@ function build() { "$MAYBE_USE_STATIC_DEPS" \ "$MAYBE_LIB_FUZZING_ENGINE" \ "$MAYBE_CUDA_COMPILER" \ + "$MAYBE_NIGHTLIES" \ .. if [ "$CUDA" = "cu118" ] || [ "$CUDA" = "cu121" ]; then @@ -240,6 +279,7 @@ function build() { "$MAYBE_OVERRIDE_CXX_FLAGS" \ "$MAYBE_USE_STATIC_DEPS" \ "$MAYBE_LIB_FUZZING_ENGINE" \ + "$MAYBE_NIGHTLIES" \ .. 
export LIBRARY_PATH=${LIBRARY_PATH}:/usr/local/opt/icu4c/lib @@ -296,8 +336,8 @@ WITH_QUIC=false INSTALL_DEPENDENCIES=false PREFIX="" COMPILER_FLAGS="" -CUDA="" -USAGE="./build.sh [-j num_jobs] [-g cu118|cu121] [-q|--with-quic] [-p|--prefix] [-x|--compiler-flags]" +CUDA="cpu" +USAGE="./build.sh [-j num_jobs] [-g cu118|cu121] [-q|--with-quic] [-t|--no-tets] [-p|--prefix] [-x|--compiler-flags] [-n|--nighlies]" while [ "$1" != "" ]; do case $1 in -j | --jobs ) shift @@ -320,6 +360,9 @@ while [ "$1" != "" ]; do shift COMPILER_FLAGS=$1 ;; + -n | --nightlies ) + USE_NIGHTLIES=true + ;; * ) echo $USAGE exit 1 esac @@ -344,8 +387,10 @@ cd $BUILD_DIR || exit BWD=$(pwd) DEPS_DIR=$BWD/_deps LIBS_DIR=$BWD/libs +TR_DIR=$BWD/test/resources/ mkdir -p "$DEPS_DIR" mkdir -p "$LIBS_DIR" +mkdir -p "$TR_DIR" # Must execute from the directory containing this script cd $BASE_DIR @@ -358,6 +403,7 @@ install_libtorch install_yaml_cpp install_sentencepiece build_llama_cpp +prepare_test_files build symlink_torch_libs symlink_yaml_cpp_lib diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt index 09e9f710e2..603a29a0e0 100644 --- a/cpp/src/examples/CMakeLists.txt +++ b/cpp/src/examples/CMakeLists.txt @@ -1,6 +1,8 @@ add_subdirectory("../../../examples/cpp/babyllama/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/babyllama/babyllama_handler/") +add_subdirectory("../../../examples/cpp/aot_inductor/llama2/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/llama_handler/") + add_subdirectory("../../../examples/cpp/llamacpp/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/llamacpp/llamacpp_handler/") add_subdirectory("../../../examples/cpp/mnist/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/mnist/mnist_handler/") diff --git a/cpp/test/examples/examples_test.cc b/cpp/test/examples/examples_test.cc index 22254288cc..00e5135715 100644 --- a/cpp/test/examples/examples_test.cc +++ b/cpp/test/examples/examples_test.cc @@ -1,9 +1,11 @@ +#include + #include #include "test/utils/common.hh" TEST_F(ModelPredictTest, TestLoadPredictBabyLlamaHandler) { - std::string base_dir = "test/resources/examples/babyllama/"; + std::string base_dir = "_build/test/resources/examples/babyllama/"; std::string file1 = base_dir + "babyllama_handler/stories15M.bin"; std::string file2 = base_dir + "babyllama_handler/tokenizer.bin"; @@ -21,14 +23,35 @@ TEST_F(ModelPredictTest, TestLoadPredictBabyLlamaHandler) { base_dir + "babyllama_handler", base_dir + "prompt.txt", "llm_ts", 200); } -TEST_F(ModelPredictTest, TestLoadPredictLlmHandler) { - std::string base_dir = "test/resources/examples/llamacpp/"; +TEST_F(ModelPredictTest, TestLoadPredictAotInductorLlamaHandler) { + std::string base_dir = "_build/test/resources/examples/aot_inductor/"; + std::string file1 = base_dir + "llama_handler/stories15M.so"; + std::string file2 = + "_build/test/resources/examples/babyllama/babyllama_handler/" + "tokenizer.bin"; + + std::ifstream f1(file1); + std::ifstream f2(file2); + + if (!f1.good() || !f2.good()) + GTEST_SKIP() << "Skipping TestLoadPredictAotInductorLlamaHandler because " + "of missing files: " + << file1 << " or " << file2; + + this->LoadPredict( + std::make_shared( + base_dir + "llama_handler", "llama", -1, "", "", 1, false), + base_dir + "llama_handler", base_dir + "prompt.txt", "llm_ts", 200); +} + +TEST_F(ModelPredictTest, TestLoadPredictLlamaCppHandler) { + std::string base_dir = "_build/test/resources/examples/llamacpp/"; std::string file1 = base_dir + 
"llamacpp_handler/llama-2-7b-chat.Q5_0.gguf"; std::ifstream f(file1); if (!f.good()) GTEST_SKIP() - << "Skipping TestLoadPredictLlmHandler because of missing file: " + << "Skipping TestLoadPredictLlamaCppHandler because of missing file: " << file1; this->LoadPredict( diff --git a/cpp/test/resources/examples/aot_inductor/llama_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/examples/aot_inductor/llama_handler/MAR-INF/MANIFEST.json new file mode 100644 index 0000000000..6f0f2d5295 --- /dev/null +++ b/cpp/test/resources/examples/aot_inductor/llama_handler/MAR-INF/MANIFEST.json @@ -0,0 +1,10 @@ +{ + "createdOn": "28/07/2020 06:32:08", + "runtime": "LSP", + "model": { + "modelName": "llama", + "handler": "libllama_so_handler:LlamaHandler", + "modelVersion": "2.0" + }, + "archiverVersion": "0.2.0" +} diff --git a/cpp/test/resources/examples/aot_inductor/llama_handler/config.json b/cpp/test/resources/examples/aot_inductor/llama_handler/config.json new file mode 100644 index 0000000000..04e1cd48ee --- /dev/null +++ b/cpp/test/resources/examples/aot_inductor/llama_handler/config.json @@ -0,0 +1,4 @@ +{ +"checkpoint_path" : "_build/test/resources/examples/aot_inductor/llama_handler/stories15M.so", +"tokenizer_path" : "_build/test/resources/examples/babyllama/babyllama_handler/tokenizer.bin" +} diff --git a/cpp/test/resources/examples/aot_inductor/prompt.txt b/cpp/test/resources/examples/aot_inductor/prompt.txt new file mode 100644 index 0000000000..74b56be151 --- /dev/null +++ b/cpp/test/resources/examples/aot_inductor/prompt.txt @@ -0,0 +1 @@ +Hello my name is diff --git a/cpp/test/resources/examples/babyllama/babyllama_handler/config.json b/cpp/test/resources/examples/babyllama/babyllama_handler/config.json index f75cd1fb53..c88e48143b 100644 --- a/cpp/test/resources/examples/babyllama/babyllama_handler/config.json +++ b/cpp/test/resources/examples/babyllama/babyllama_handler/config.json @@ -1,4 +1,4 @@ { -"checkpoint_path" : "test/resources/examples/babyllama/babyllama_handler/stories15M.bin", -"tokenizer_path" : "test/resources/examples/babyllama/babyllama_handler/tokenizer.bin" +"checkpoint_path" : "_build/test/resources/examples/babyllama/babyllama_handler/stories15M.bin", +"tokenizer_path" : "_build/test/resources/examples/babyllama/babyllama_handler/tokenizer.bin" } diff --git a/cpp/test/resources/examples/llamacpp/llamacpp_handler/config.json b/cpp/test/resources/examples/llamacpp/llamacpp_handler/config.json new file mode 100644 index 0000000000..46169be4ea --- /dev/null +++ b/cpp/test/resources/examples/llamacpp/llamacpp_handler/config.json @@ -0,0 +1,3 @@ +{ + "checkpoint_path" : "_build/test/resources/examples/llamacpp/llamacpp_handler/llama-2-7b-chat.Q5_0.gguf" +} diff --git a/cpp/test/torch_scripted/torch_scripted_test.cc b/cpp/test/torch_scripted/torch_scripted_test.cc index ecb1d7f69f..5f1c986151 100644 --- a/cpp/test/torch_scripted/torch_scripted_test.cc +++ b/cpp/test/torch_scripted/torch_scripted_test.cc @@ -9,44 +9,47 @@ TEST_F(ModelPredictTest, TestLoadPredictBaseHandler) { this->LoadPredict(std::make_shared( - "test/resources/examples/mnist/mnist_handler", + "_build/test/resources/examples/mnist/mnist_handler", "mnist_scripted_v2", -1, "", "", 1, false), - "test/resources/examples/mnist/base_handler", - "test/resources/examples/mnist/0_png.pt", "mnist_ts", 200); + "_build/test/resources/examples/mnist/base_handler", + "_build/test/resources/examples/mnist/0_png.pt", "mnist_ts", + 200); } TEST_F(ModelPredictTest, TestLoadPredictMnistHandler) { 
  this->LoadPredict(std::make_shared<torchserve::LoadModelRequest>(
-                        "test/resources/examples/mnist/mnist_handler",
+                        "_build/test/resources/examples/mnist/mnist_handler",
                         "mnist_scripted_v2", -1, "", "", 1, false),
-                    "test/resources/examples/mnist/mnist_handler",
-                    "test/resources/examples/mnist/0_png.pt", "mnist_ts", 200);
+                    "_build/test/resources/examples/mnist/mnist_handler",
+                    "_build/test/resources/examples/mnist/0_png.pt", "mnist_ts",
+                    200);
 }

 TEST_F(ModelPredictTest, TestBackendInitWrongModelDir) {
-  auto result = backend_->Initialize("test/resources/examples/mnist");
+  auto result = backend_->Initialize("_build/test/resources/examples/mnist");
   ASSERT_EQ(result, false);
 }

 TEST_F(ModelPredictTest, TestBackendInitWrongHandler) {
-  auto result =
-      backend_->Initialize("test/resources/examples/mnist/wrong_handler");
+  auto result = backend_->Initialize(
+      "_build/test/resources/examples/mnist/wrong_handler");
   ASSERT_EQ(result, false);
 }

 TEST_F(ModelPredictTest, TestLoadModelFailure) {
-  backend_->Initialize("test/resources/examples/mnist/wrong_model");
+  backend_->Initialize("_build/test/resources/examples/mnist/wrong_model");
   auto result =
       backend_->LoadModel(std::make_shared<torchserve::LoadModelRequest>(
-          "test/resources/examples/mnist/wrong_model", "mnist_scripted_v2", -1,
-          "", "", 1, false));
+          "_build/test/resources/examples/mnist/wrong_model",
+          "mnist_scripted_v2", -1, "", "", 1, false));
   ASSERT_EQ(result->code, 500);
 }

 TEST_F(ModelPredictTest, TestLoadPredictMnistHandlerFailure) {
   this->LoadPredict(std::make_shared<torchserve::LoadModelRequest>(
-                        "test/resources/examples/mnist/mnist_handler",
+                        "_build/test/resources/examples/mnist/mnist_handler",
                         "mnist_scripted_v2", -1, "", "", 1, false),
-                    "test/resources/examples/mnist/mnist_handler",
-                    "test/resources/examples/mnist/0.png", "mnist_ts", 500);
+                    "_build/test/resources/examples/mnist/mnist_handler",
+                    "_build/test/resources/examples/mnist/0.png", "mnist_ts",
+                    500);
 }
diff --git a/cpp/third-party/llama2.so b/cpp/third-party/llama2.so
new file mode 160000
index 0000000000..ac438a5049
--- /dev/null
+++ b/cpp/third-party/llama2.so
@@ -0,0 +1 @@
+Subproject commit ac438a5049b5f25f473d49c13c9b26f8f5870d54
diff --git a/examples/cpp/aot_inductor/llama2/CMakeLists.txt b/examples/cpp/aot_inductor/llama2/CMakeLists.txt
new file mode 100644
index 0000000000..1826330c83
--- /dev/null
+++ b/examples/cpp/aot_inductor/llama2/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_library(llama2_so STATIC ../../../../cpp/third-party/llama2.so/run.cpp)
+target_compile_options(llama2_so PRIVATE -Wall -Wextra -Ofast -fpermissive)
+
+add_library(llama_so_handler SHARED src/llama_handler.cc)
+target_link_libraries(llama_so_handler PRIVATE llama2_so ts_backends_core ts_utils ${TORCH_LIBRARIES})
diff --git a/examples/cpp/aot_inductor/llama2/README.md b/examples/cpp/aot_inductor/llama2/README.md
new file mode 100644
index 0000000000..b3e9261b84
--- /dev/null
+++ b/examples/cpp/aot_inductor/llama2/README.md
@@ -0,0 +1,87 @@
+This example uses Bert Maher's [llama2.so](https://github.com/bertmaher/llama2.so/), which is a fork of Andrej Karpathy's [llama2.c](https://github.com/karpathy/llama2.c).
+It uses AOTInductor to compile the model into a .so file, which is then executed using libtorch.
+The handler C++ source code for this example can be found [here](src/).
+
+### Setup
+1. Follow the instructions in [README.md](../../../../cpp/README.md) to build the TorchServe C++ backend.
+
+```
+cd serve/cpp
+./build.sh
+```
+
+The build script will already create the necessary artifacts for this example.
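+If you want to recreate them by hand, the commands below are a sketch of the equivalent steps; they assume you run them from this example's directory, so adjust the paths to your local setup.
+
+```bash
+# Download the stories15M checkpoint and the tokenizer (the same files the build fetches)
+wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -O stories15M.pt
+wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin -O tokenizer.bin
+
+# Compile the checkpoint into stories15M.so with AOTInductor; the llama2.so
+# submodule has to be on PYTHONPATH so that compile.py can import `model`
+PYTHONPATH=../../../../cpp/third-party/llama2.so:$PYTHONPATH \
+  python compile.py --checkpoint stories15M.pt stories15M.so
+```
+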
+These commands mirror the `prepare_test_files` function of the [build.sh](../../../../cpp/build.sh) script, which you can consult for the exact steps the build runs.
+We will need the handler .so file as well as the stories15M.so file containing the model and weights.
+
+2. Copy the handler file:
+
+```bash
+cd ~/serve/examples/cpp/aot_inductor/llama2
+cp ../../../../cpp/_build/test/resources/examples/aot_inductor/llama_handler/libllama_so_handler.so ./
+```
+We will leave the model .so file in place and just use its [path](../../../../cpp/_build/test/resources/examples/aot_inductor/llama_handler/stories15M.so) in the next step.
+
+3. Create a [config.json](config.json) with the paths of the compiled model and the tokenizer:
+
+```bash
+echo '{
+"checkpoint_path" : "/home/ubuntu/serve/cpp/_build/test/resources/examples/aot_inductor/llama_handler/stories15M.so",
+"tokenizer_path" : "/home/ubuntu/serve/cpp/_build/test/resources/examples/babyllama/babyllama_handler/tokenizer.bin"
+}' > config.json
+```
+
+The tokenizer is the same one used in the babyllama example, so we can reuse the file from there.
+
+### Generate MAR file
+
+Now let's generate the MAR file:
+
+```bash
+torch-model-archiver --model-name llm --version 1.0 --handler libllama_so_handler:LlamaHandler --runtime LSP --extra-files config.json
+```
+
+Create a model store directory and move the MAR file:
+
+```
+mkdir model_store
+mv llm.mar model_store/
+```
+
+### Inference
+
+Start TorchServe using the following command:
+
+```
+torchserve --ncs --model-store model_store/
+```
+
+Register the model using the following command:
+
+```
+curl -v -X POST "http://localhost:8081/models?initial_workers=1&url=llm.mar&batch_size=2&max_batch_delay=5000"
+```
+
+Infer the model using the following command:
+
+```
+curl http://localhost:8080/predictions/llm -T prompt1.txt
+```
+
+This example supports batching. To run a batched prediction, run the following command:
+
+```
+curl http://localhost:8080/predictions/llm -T prompt1.txt & curl http://localhost:8080/predictions/llm -T prompt2.txt &
+```
+
+Sample Response:
+
+```
+Hello my name is Daisy. Daisy is three years old. She loves to play with her toys.
+One day, Daisy's mommy said, "Daisy, it's time to go to the store." Daisy was so excited! She ran to the store with her mommy.
+At the store, Daisy saw a big, red balloon. She wanted it so badly! She asked her mommy, "Can I have the balloon, please?"
+Mommy said, "No, Daisy. We don't have enough money for that balloon."
+Daisy was sad. She wanted the balloon so much. She started to cry.
+Mommy said, "Daisy, don't cry. We can get the balloon. We can buy it and take it home."
+Daisy smiled. She was so happy. She hugged her mommy and said, "Thank you, mommy!"
+```
diff --git a/examples/cpp/aot_inductor/llama2/compile.py b/examples/cpp/aot_inductor/llama2/compile.py
new file mode 100644
index 0000000000..0906e4942f
--- /dev/null
+++ b/examples/cpp/aot_inductor/llama2/compile.py
@@ -0,0 +1,39 @@
+import argparse
+
+import torch
+import torch._export
+from model import ModelArgs, Transformer
+
+
+def load_checkpoint(checkpoint):
+    # load the provided model checkpoint
+    checkpoint_dict = torch.load(checkpoint, map_location="cpu")
+    gptconf = ModelArgs(**checkpoint_dict["model_args"])
+    model = Transformer(gptconf)
+    state_dict = checkpoint_dict["model"]
+    unwanted_prefix = "_orig_mod."
+ for k, v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) + model.load_state_dict(state_dict, strict=False) + model.eval() + return model, gptconf + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "filepath", type=str, default="llama2.so", help="the output filepath" + ) + parser.add_argument("--checkpoint", type=str, help="checkpoint .pt") + args = parser.parse_args() + model, config = load_checkpoint(args.checkpoint) + x = torch.randint(0, config.vocab_size, (1, config.max_seq_len // 2)) + seq_len_dim = torch.export.Dim("seq_len", min=1, max=config.max_seq_len) + torch._C._GLIBCXX_USE_CXX11_ABI = True + so_path = torch._export.aot_compile( + model, + (x,), + dynamic_shapes={"tokens": (None, seq_len_dim)}, + options={"aot_inductor.output_path": args.filepath}, + ) diff --git a/examples/cpp/aot_inductor/llama2/config.json b/examples/cpp/aot_inductor/llama2/config.json new file mode 100644 index 0000000000..9d8728df8f --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/config.json @@ -0,0 +1,4 @@ +{ +"checkpoint_path" : "/home/ubuntu/serve/cpp/_build/test/resources/examples/aot_inductor/llama_handler/stories15M.so", +"tokenizer_path" : "/home/ubuntu/serve/cpp/_build/test/resources/examples/babyllama/babyllama_handler/tokenizer.bin" +} diff --git a/examples/cpp/aot_inductor/llama2/prompt1.txt b/examples/cpp/aot_inductor/llama2/prompt1.txt new file mode 100644 index 0000000000..baa5a1abbf --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/prompt1.txt @@ -0,0 +1 @@ +Hello my name is Dan diff --git a/examples/cpp/aot_inductor/llama2/prompt2.txt b/examples/cpp/aot_inductor/llama2/prompt2.txt new file mode 100644 index 0000000000..99568648e9 --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/prompt2.txt @@ -0,0 +1 @@ +Hello my name is Daisy diff --git a/examples/cpp/aot_inductor/llama2/src/llama2.so/LICENSE b/examples/cpp/aot_inductor/llama2/src/llama2.so/LICENSE new file mode 100644 index 0000000000..2ad12227f9 --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/src/llama2.so/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Andrej + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
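Bridging the compile step above and the handler sources that follow: the stories15M.so emitted by compile.py is executed through torch::inductor::AOTIModelContainerRunnerCpu, the runner held by the Transformer struct declared in llama2.hh below. The standalone C++ sketch here shows that flow in isolation; the include path and the exact run() signature are assumptions that vary between torch releases, so treat it as illustrative rather than as code from this example.

```cpp
// Illustrative sketch only: the header location and run() signature are
// assumptions and differ across torch versions.
#include <iostream>
#include <vector>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#include <torch/torch.h>

int main() {
  // Load the AOTInductor model container packaged inside the shared object.
  torch::inductor::AOTIModelContainerRunnerCpu runner("stories15M.so");

  // A [1, seq_len] batch of token ids, matching the dynamic "tokens" input
  // that compile.py exported.
  std::vector<torch::Tensor> inputs{
      torch::randint(0, 32000, {1, 16}, torch::dtype(torch::kLong))};

  // Run the compiled forward pass; the first output tensor holds the logits.
  std::vector<torch::Tensor> outputs = runner.run(inputs);
  std::cout << "logits shape: " << outputs[0].sizes() << std::endl;
  return 0;
}
```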
diff --git a/examples/cpp/aot_inductor/llama2/src/llama2.so/llama2.hh b/examples/cpp/aot_inductor/llama2/src/llama2.so/llama2.hh new file mode 100644 index 0000000000..ef8bbe8f99 --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/src/llama2.so/llama2.hh @@ -0,0 +1,72 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// ---------------------------------------------------------------------------- +// Transformer model + +typedef struct { + int vocab_size; // vocabulary size, usually 256 (byte-level) + int seq_len; // max sequence length +} Config; + +typedef struct { + float *logits; // output logits + int64_t* toks; // tokens seen so far; no kv-cache :( +} RunState; + +typedef struct { + Config config; // the hyperparameters of the architecture (the blueprint) + RunState state; // buffers for the "wave" of activations in the forward pass + torch::inductor::AOTIModelContainerRunnerCpu *runner; +} Transformer; +// ---------------------------------------------------------------------------- +// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens + +typedef struct { + char *str; + int id; +} TokenIndex; + +typedef struct { + char** vocab; + float* vocab_scores; + TokenIndex *sorted_vocab; + int vocab_size; + unsigned int max_token_length; + unsigned char byte_pieces[512]; // stores all single-byte strings +} Tokenizer; + +// ---------------------------------------------------------------------------- +// The Sampler, which takes logits and returns a sampled token +// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling + +typedef struct { + float prob; + int index; +} ProbIndex; // struct used when sorting probabilities during top-p sampling + +typedef struct { + int vocab_size; + ProbIndex* probindex; // buffer used in top-p sampling + float temperature; + float topp; + unsigned long long rng_state; +} Sampler; +void build_transformer(Transformer *t, char* checkpoint_path, int vocab_size, int seq_len); +void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size); +void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed); +void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens); +float* forward(Transformer* transformer, int token, int pos); +int sample(Sampler* sampler, float* logits); +long time_in_ms(); +char* decode(Tokenizer* t, int prev_token, int token); +void free_sampler(Sampler* sampler); +void free_tokenizer(Tokenizer* t); +void free_transformer(Transformer* t); diff --git a/examples/cpp/aot_inductor/llama2/src/llama_handler.cc b/examples/cpp/aot_inductor/llama2/src/llama_handler.cc new file mode 100644 index 0000000000..40f6c67b08 --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/src/llama_handler.cc @@ -0,0 +1,302 @@ +#include "llama_handler.hh" + +#include +#include + +#include + +#include "llama2.so/llama2.hh" + + +namespace llm { + +Transformer transformer; +Tokenizer tokenizer; +Sampler sampler; +int steps = 256; + +std::pair, std::shared_ptr> +LlamaHandler::LoadModel( + std::shared_ptr &load_model_request) { + try { + auto device = GetTorchDevice(load_model_request); + + const std::string configFilePath = + fmt::format("{}/{}", load_model_request->model_dir, "config.json"); + std::string jsonContent; + if (!folly::readFile(configFilePath.c_str(), jsonContent)) { + std::cerr << "config.json not found at: " << configFilePath << std::endl; + throw; + } + folly::dynamic json; + 
json = folly::parseJson(jsonContent); + std::string checkpoint_path; + std::string tokenizer_path; + if (json.find("checkpoint_path") != json.items().end() && + json.find("tokenizer_path") != json.items().end()) { + checkpoint_path = json["checkpoint_path"].asString(); + tokenizer_path = json["tokenizer_path"].asString(); + } else { + std::cerr + << "Required fields 'model_name' and 'model_path' not found in JSON." + << std::endl; + throw; + } + + build_transformer(&transformer, const_cast(checkpoint_path.c_str()), 32000, 256); + + build_tokenizer(&tokenizer, const_cast(tokenizer_path.c_str()), + transformer.config.vocab_size); + + float temperature = + 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher + float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, + // but slower + unsigned long long rng_seed(0); + // build the Sampler + build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, + rng_seed); + + return std::make_pair(nullptr, device); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.msg()); + throw e; + } catch (const std::runtime_error &e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.what()); + throw e; + } +} + +c10::IValue LlamaHandler::Preprocess( + std::shared_ptr &device, + std::pair &> &idx_to_req_id, + std::shared_ptr &request_batch, + std::shared_ptr &response_batch) { + auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); + std::vector batch_tensors; + uint8_t idx = 0; + for (auto &request : *request_batch) { + try { + (*response_batch)[request.request_id] = + std::make_shared(request.request_id); + idx_to_req_id.first += idx_to_req_id.first.empty() + ? 
request.request_id + : "," + request.request_id; + + auto data_it = request.parameters.find( + torchserve::PayloadType::kPARAMETER_NAME_DATA); + auto dtype_it = + request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE); + if (data_it == request.parameters.end()) { + data_it = request.parameters.find( + torchserve::PayloadType::kPARAMETER_NAME_BODY); + dtype_it = request.headers.find( + torchserve::PayloadType::kHEADER_NAME_BODY_TYPE); + } + + if (data_it == request.parameters.end() || + dtype_it == request.headers.end()) { + TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id); + (*response_batch)[request.request_id]->SetResponse( + 500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT, + "Empty payload"); + continue; + } + + std::string msg = torchserve::Converter::VectorToStr(data_it->second); + + int num_prompt_tokens = 0; + + std::unique_ptr msgCStr( + new char[msg.size() + 1], [](char *ptr) { delete[] ptr; }); + + std::strcpy(msgCStr.get(), msg.c_str()); + + std::unique_ptr prompt_tokens(new int[msg.length() + 3]); + + encode(&tokenizer, msgCStr.get(), 1, 0, prompt_tokens.get(), + &num_prompt_tokens); + + std::vector tensor_vector; + for (int64_t i = 0; i < num_prompt_tokens; ++i) { + int token = prompt_tokens[i]; + torch::Tensor tensor = torch::tensor(token, torch::kInt64); + tensor_vector.push_back(tensor); + } + batch_ivalue.emplace_back(torch::stack(tensor_vector)); + + idx_to_req_id.second[idx++] = request.request_id; + } catch (const std::runtime_error &e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", + request.request_id, e.what()); + auto response = (*response_batch)[request.request_id]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "runtime_error, failed to load tensor"); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error:{}", + request.request_id, e.msg()); + auto response = (*response_batch)[request.request_id]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "c10 error, failed to load tensor"); + } + } + + return batch_ivalue; +} + +c10::IValue LlamaHandler::Inference( + std::shared_ptr model, c10::IValue &inputs, + std::shared_ptr &device, + std::pair &> &idx_to_req_id, + std::shared_ptr &response_batch) { + torch::InferenceMode guard; + auto batch_output_vector = c10::impl::GenericList(torch::TensorType::get()); + long batch_token_length = 0; + long start = + 0; // used to time our code, only initialized after first iteration + + try { + for (auto input : inputs.toTensorList()) { + std::vector tensor_vector; + tensor_vector.reserve(steps); + torch::Tensor tokens_list_tensor = input.get().toTensor(); + + int64_t num_elements = tokens_list_tensor.numel(); + + int64_t *data_ptr = tokens_list_tensor.data_ptr(); + + std::unique_ptr prompt_tokens(new int[num_elements]); + + for (int64_t i = 0; i < num_elements; ++i) { + prompt_tokens[i] = data_ptr[i]; + } + + // start the main loop + int next; // will store the next token in the sequence + int token = + prompt_tokens[0]; // kick off with the first token in the prompt + int pos = 0; // position in the sequence + while (pos < steps) { + // forward the transformer to get logits for the next token + float *logits = forward(&transformer, token, pos); + + // advance the state state machine + if (pos < num_elements - 1) { + // if we are still processing the input prompt, force the next prompt + // token + next = prompt_tokens[pos 
+ 1]; + } else { + // otherwise sample the next token from the logits + next = sample(&sampler, logits); + } + pos++; + + torch::Tensor tensor = torch::tensor(next, torch::kLong); + tensor_vector.push_back(tensor); + + // data-dependent terminating condition: the BOS (=1) token delimits + // sequences + if (next == 1) { + break; + } + token = next; + + // init the timer here because the first iteration can be slower + if (start == 0) { + start = time_in_ms(); + } + } + batch_token_length = batch_token_length + pos - 1; + + torch::Tensor stacked_tensor = torch::stack(tensor_vector); + + batch_output_vector.push_back(stacked_tensor); + } + + TS_LOGF(DEBUG, "Total number of tokens generated: {}", batch_token_length); + if (batch_token_length > 1) { + long end = time_in_ms(); + double token_per_sec = batch_token_length / (double)(end - start) * 1000; + TS_LOGF(DEBUG, "Achieved tok per sec: {}", token_per_sec); + } + } catch (std::runtime_error &e) { + TS_LOG(ERROR, e.what()); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, "Failed to apply inference on input, c10 error:{}", e.msg()); + } catch (...) { + TS_LOG(ERROR, "Failed to run inference on this batch"); + } + return batch_output_vector; +} + +void LlamaHandler::Postprocess( + c10::IValue &outputs, + std::pair &> &idx_to_req_id, + std::shared_ptr &response_batch) { + auto data = outputs.toTensorList(); + for (const auto &kv : idx_to_req_id.second) { + try { + int64_t num_elements = data[kv.first].get().toTensor().numel(); + int64_t *data_ptr = data[kv.first].get().toTensor().data_ptr(); + int64_t token = 1; + std::string concatenated_string; + for (int64_t i = 0; i < num_elements; ++i) { + char *piece = decode(&tokenizer, token, data_ptr[i]); + std::string piece_string(piece); + token = data_ptr[i]; + concatenated_string += piece_string; + } + + TS_LOGF(DEBUG, "Generated String: {}", concatenated_string); + + auto response = (*response_batch)[kv.second]; + + response->SetResponse(200, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + concatenated_string); + } catch (const std::runtime_error &e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", + kv.second, e.what()); + auto response = (*response_batch)[kv.second]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "runtime_error, failed to postprocess tensor"); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, + "Failed to postprocess tensor for request id: {}, error: {}", + kv.second, e.msg()); + auto response = (*response_batch)[kv.second]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "c10 error, failed to postprocess tensor"); + } + } +} + +LlamaHandler::~LlamaHandler() noexcept { + free_sampler(&sampler); + free_tokenizer(&tokenizer); + free_transformer(&transformer); +} + +} // namespace llm + +#if defined(__linux__) || defined(__APPLE__) +extern "C" { +torchserve::BaseHandler *allocatorLlamaHandler() { + return new llm::LlamaHandler(); +} + +void deleterLlamaHandler(torchserve::BaseHandler *p) { + if (p != nullptr) { + delete static_cast(p); + } +} +} +#endif diff --git a/examples/cpp/aot_inductor/llama2/src/llama_handler.hh b/examples/cpp/aot_inductor/llama2/src/llama_handler.hh new file mode 100644 index 0000000000..7c2c067acc --- /dev/null +++ b/examples/cpp/aot_inductor/llama2/src/llama_handler.hh @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "src/backends/handler/base_handler.hh" + +namespace llm { +class LlamaHandler : public 
torchserve::BaseHandler { + public: + // NOLINTBEGIN(bugprone-exception-escape) + LlamaHandler() = default; + // NOLINTEND(bugprone-exception-escape) + ~LlamaHandler() noexcept; + + void initialize_context(); + + std::pair, std::shared_ptr> LoadModel( + std::shared_ptr& load_model_request) + override; + + c10::IValue Preprocess( + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& request_batch, + std::shared_ptr& response_batch) + override; + + c10::IValue Inference( + std::shared_ptr model, c10::IValue& inputs, + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) + override; + + void Postprocess( + c10::IValue& data, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) + override; +}; +} // namespace llm diff --git a/examples/cpp/babyllama/README.md b/examples/cpp/babyllama/README.md index cba4df5cd5..cd68eec93a 100644 --- a/examples/cpp/babyllama/README.md +++ b/examples/cpp/babyllama/README.md @@ -1,3 +1,5 @@ +## BabyLlama example + This example is adapted from https://github.com/karpathy/llama2.c. The handler C++ source code for this examples can be found [here](../../../cpp/src/examples/babyllama/). ### Setup @@ -30,7 +32,7 @@ echo '{ While building the C++ backend the `libbabyllama_handler.so` file is generated in the [babyllama_handler](../../../cpp/test/resources/examples/babyllama/babyllama_handler) folder. ```bash -cp ../../../cpp/test/resources/examples/babyllama/babyllama_handler/libbabyllama_handler.so ./ +cp ../../../cpp/_build/test/resources/examples/babyllama/babyllama_handler/libbabyllama_handler.so ./ ``` ### Generate MAR file diff --git a/examples/cpp/llamacpp/README.md b/examples/cpp/llamacpp/README.md index 8221262858..f0ab891e52 100644 --- a/examples/cpp/llamacpp/README.md +++ b/examples/cpp/llamacpp/README.md @@ -1,3 +1,5 @@ +## Llama.cpp example + This example used [llama.cpp](https://github.com/ggerganov/llama.cpp) to deploy a Llama-2-7B-Chat model using the TorchServe C++ backend. The handler C++ source code for this examples can be found [here](../../../cpp/src/examples/llamacpp/). @@ -29,7 +31,7 @@ echo '{ While building the C++ backend the `libllamacpp_handler.so` file is generated in the [llamacpp_handler](../../../cpp/test/resources/examples/llamacpp/llamacpp_handler) folder. ```bash -cp ../../../cpp/test/resources/examples/llamacpp/llamacpp_handler/libllamacpp_handler.so ./ +cp ../../../cpp/_build/test/resources/examples/llamacpp/llamacpp_handler/libllamacpp_handler.so ./ ``` ### Generate MAR file diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index d9133f66a9..a6c91631a1 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1170,9 +1170,13 @@ bfloat bb babyllama libbabyllama -BabyLLama +BabyLlama BabyLlamaHandler CMakeLists TorchScriptHandler libllamacpp +libtorch +Andrej +Karpathy's +Maher's warmup
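A closing note on the handler naming used throughout these examples (libbabyllama_handler:BabyLlamaHandler, libllama_so_handler:LlamaHandler, and so on): the part before the colon is the shared library, the part after it is the handler class, and the library exports C-linkage allocator/deleter entry points like the extern "C" block at the end of llama_handler.cc. The sketch below illustrates that loading convention with plain dlopen/dlsym; it is not TorchServe code, and it assumes the library and its dependencies (libtorch, the TorchServe core libraries) are resolvable from the current directory.

```cpp
// Hypothetical illustration of the handler loading convention, not TorchServe code.
#include <dlfcn.h>

#include <iostream>

int main() {
  // Open the shared library named before the colon in "libllama_so_handler:LlamaHandler".
  void* lib = dlopen("./libllama_so_handler.so", RTLD_LAZY | RTLD_GLOBAL);
  if (lib == nullptr) {
    std::cerr << dlerror() << std::endl;
    return 1;
  }

  // Resolve the C-linkage entry points exported by the handler library
  // (see the extern "C" block in llama_handler.cc).
  using Allocator = void* (*)();
  using Deleter = void (*)(void*);
  auto alloc = reinterpret_cast<Allocator>(dlsym(lib, "allocatorLlamaHandler"));
  auto release = reinterpret_cast<Deleter>(dlsym(lib, "deleterLlamaHandler"));
  if (alloc == nullptr || release == nullptr) {
    std::cerr << "handler entry points not found" << std::endl;
    return 1;
  }

  void* handler = alloc();  // constructs an llm::LlamaHandler
  release(handler);         // destroys it through the matching deleter
  dlclose(lib);
  return 0;
}
```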