diff --git a/examples/rkllm/.gitignore b/examples/rkllm/.gitignore new file mode 100644 index 0000000..f939318 --- /dev/null +++ b/examples/rkllm/.gitignore @@ -0,0 +1 @@ +!librkllmrt.so diff --git a/examples/rkllm/README.md b/examples/rkllm/README.md new file mode 100644 index 0000000..880748c --- /dev/null +++ b/examples/rkllm/README.md @@ -0,0 +1,135 @@ +# RKLLM Workflow (Conversion + C++/Python Inference on Axon) + +This `rkllm/` folder contains: +- `convert.py`: converts a Hugging Face model directory to `.rkllm` +- `inference.cpp`: RKLLM C++ runtime inference app +- `inference.py`: RKLLM Python runtime inference app (ctypes wrapper) +- `dataset.json`: optional calibration dataset +- `rkllm.h` + `librkllmrt.so`: RKLLM C++/Python runtime build dependencies + +Recommended flow: +1. Convert once on host machine +2. Copy the generated `.rkllm` model to Axon +3. Run inference on Axon using either C++ or Python + +## 0) Get Started + +```bash +git clone https://github.com/vicharak-in/Axon-NPU-Guide.git +cd Axon-NPU-Guide/rkllm +``` + +`rkllm/` is the working folder for this guide and includes the required runtime files. 
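Before moving on, it can help to verify that the working folder actually contains everything listed above. A minimal preflight sketch in Python (the `REQUIRED_FILES` list mirrors the folder contents described in this README; `missing_rkllm_files` is a hypothetical helper for illustration, not part of the RKLLM toolkit):

```python
from pathlib import Path

# Files this README expects inside the rkllm/ working folder
# (hypothetical checklist derived from the folder listing above).
REQUIRED_FILES = [
    "convert.py",
    "inference.cpp",
    "inference.py",
    "rkllm.h",
    "librkllmrt.so",
]

def missing_rkllm_files(folder: str) -> list:
    """Return the required files that are absent from `folder`."""
    root = Path(folder)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]

if __name__ == "__main__":
    missing = missing_rkllm_files(".")
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("all required files present")
```

Run it from inside `rkllm/`; an empty result means the compile and inference steps later in this guide will find `rkllm.h` and `librkllmrt.so` next to the sources.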
+
+## 1) Common Conversion
+
+### 1.1 Create environment + get toolkit
+
+```bash
+python3 -m venv venv-rkllm
+source venv-rkllm/bin/activate
+
+git clone https://github.com/airockchip/rknn-llm.git
+```
+
+If using Python 3.12:
+
+```bash
+export BUILD_CUDA_EXT=0
+pip install rknn-llm/rkllm-toolkit/packages/rkllm_toolkit-1.2.3-cp312-cp312-linux_x86_64.whl
+```
+
+If you hit `No module named pkg_resources`:
+
+```bash
+pip install "setuptools==68.0.0"
+```
+
+### 1.2 Download model from Hugging Face
+
+```bash
+sudo apt install -y git-lfs
+git lfs install
+
+git clone https://huggingface.co/Qwen/Qwen3-0.6B
+# Example alternative:
+# git clone https://huggingface.co/Qwen/Qwen2-1.5B
+```
+
+### 1.3 Convert to RKLLM
+
+Qwen3-0.6B example:
+
+```bash
+python3 convert.py -i ./Qwen3-0.6B -o <output.rkllm> --device cpu --dtype float32 --quantized-dtype w8a8 --quantized-algorithm normal --optimization-level 1 --num-npu-core 3 --target-platform rk3588 --max-context 4096
+```
+
+Notes:
+- Add `--dataset dataset.json` to enable quantization with a calibration dataset.
+- `--max-context` must be `>0`, `<=16384`, and a multiple of `32`.
+- `--quantized-algorithm grq/gdq` requires `--device cuda` in `convert.py`.
+
+After conversion, copy only the generated `.rkllm` model file to your Axon `rkllm/` folder.
+
+---
+
+## 2) C++ Inference on Axon
+
+### 2.1 Compile
+
+```bash
+g++ -O2 -std=c++17 -I. inference.cpp -L. -lrkllmrt -Wl,-rpath,'$ORIGIN' -o inference
+```
+> Note: keep `librkllmrt.so` and `rkllm.h` in the same directory as `inference.cpp` for the command above to work.
+
+### 2.2 Run
+
+```bash
+./inference --model <model.rkllm> --target-platform rk3588 --stream --print-perf --keep-history
+```
+
+Useful behavior:
+- If `--prompt` is not passed, it starts interactive mode.
+- Interactive commands:
+  - `exit`
+  - `clear` (clears KV cache)
+
+---
+
+## 3) Python Inference on Axon (with venv)
+
+`inference.py` uses only the Python stdlib + `librkllmrt.so`, so a lightweight venv is enough.
+
+### 3.1 Create env
+
+```bash
+python3 -m venv venv-rkllm
+source venv-rkllm/bin/activate
+```
+
+### 3.2 Inference
+
+Single-shot prompt:
+
+```bash
+python3 inference.py -m <model.rkllm> --target-platform rk3588 --stream --print-perf --prompt "Hello"
+```
+
+Keep chat history across turns (interactive mode):
+
+```bash
+python3 inference.py -m <model.rkllm> --target-platform rk3588 --stream --print-perf --keep-history
+```
+
+---
+
+## 4) Troubleshooting
+
+- `ModuleNotFoundError: No module named 'pkg_resources'`
+  - Run: `pip install "setuptools==68.0.0"`
+
+- `OSError: librkllmrt.so: cannot open shared object file`
+  - Pass `--runtime-lib /full/path/librkllmrt.so` or set `LD_LIBRARY_PATH`.
+  - Confirm you are using the Linux `aarch64` runtime on Axon.
+---
\ No newline at end of file
diff --git a/examples/rkllm/dataset.json b/examples/rkllm/dataset.json
new file mode 100644
index 0000000..9373b2f
--- /dev/null
+++ b/examples/rkllm/dataset.json
@@ -0,0 +1,42 @@
+[
+{"input":"Explain gravity in simple terms.","target":"Gravity is the force that attracts objects with mass toward each other."},
+{"input":"What is artificial intelligence?","target":"Artificial intelligence is the field of creating machines that can perform tasks requiring human-like intelligence."},
+{"input":"Describe how photosynthesis works.","target":"Photosynthesis is the process plants use to convert sunlight, water, and carbon dioxide into glucose and oxygen."},
+{"input":"What causes earthquakes?","target":"Earthquakes occur when tectonic plates suddenly shift along faults in the Earth's crust."},
+{"input":"Explain the purpose of the Internet.","target":"The Internet connects computers worldwide to share information, services, and communication."},
+{"input":"What is machine learning?","target":"Machine
learning is a branch of AI that enables computers to learn patterns from data without explicit programming."}, +{"input":"Define climate change.","target":"Climate change refers to long-term shifts in temperature and weather patterns caused largely by human activity."}, +{"input":"Explain the concept of black holes.","target":"Black holes are extremely dense regions of space where gravity is so strong that nothing can escape."}, +{"input":"What is a computer algorithm?","target":"An algorithm is a step-by-step procedure used to solve a problem or perform a computation."}, +{"input":"Describe the water cycle.","target":"The water cycle describes how water evaporates, condenses into clouds, and returns to Earth as precipitation."}, +{"input":"What is the purpose of education?","target":"Education helps people gain knowledge, develop skills, and understand the world around them."}, +{"input":"Explain renewable energy.","target":"Renewable energy comes from sources like sunlight, wind, and water that naturally replenish."}, +{"input":"What is quantum computing?","target":"Quantum computing uses quantum mechanics principles like superposition and entanglement to process information."}, +{"input":"Describe neural networks.","target":"Neural networks are machine learning models inspired by the brain that learn patterns through layers of interconnected nodes."}, +{"input":"What is data science?","target":"Data science combines statistics, computing, and domain knowledge to extract insights from data."}, +{"input":"Explain natural language processing.","target":"Natural language processing allows computers to understand, interpret, and generate human language."}, +{"input":"What is the role of satellites?","target":"Satellites orbit Earth to provide communication, navigation, weather monitoring, and scientific observation."}, +{"input":"Describe the solar system.","target":"The solar system consists of the Sun and the celestial bodies that orbit it, including planets and 
asteroids."}, +{"input":"What is cybersecurity?","target":"Cybersecurity protects computer systems and networks from attacks, theft, and damage."}, +{"input":"Explain blockchain technology.","target":"Blockchain is a decentralized digital ledger that securely records transactions across many computers."}, +{"input":"What is robotics?","target":"Robotics is the engineering field focused on designing and building machines that perform automated tasks."}, +{"input":"Describe cloud computing.","target":"Cloud computing provides computing resources like servers and storage over the internet."}, +{"input":"What is big data?","target":"Big data refers to extremely large datasets that require advanced tools for processing and analysis."}, +{"input":"Explain computer vision.","target":"Computer vision enables machines to interpret and understand visual information from images and videos."}, +{"input":"What is deep learning?","target":"Deep learning is a subset of machine learning using multi-layer neural networks to model complex patterns."}, +{"input":"Describe autonomous vehicles.","target":"Autonomous vehicles use sensors, AI, and control systems to navigate without human drivers."}, +{"input":"What is edge computing?","target":"Edge computing processes data closer to where it is generated to reduce latency and bandwidth usage."}, +{"input":"Explain the concept of sustainability.","target":"Sustainability involves using resources responsibly so future generations can also meet their needs."}, +{"input":"What is genetic engineering?","target":"Genetic engineering involves modifying the DNA of organisms to achieve specific traits."}, +{"input":"Describe the function of DNA.","target":"DNA stores genetic instructions used in the growth and functioning of living organisms."}, +{"input":"What is a database?","target":"A database is an organized collection of structured information that can be easily accessed and managed."}, +{"input":"Explain operating systems.","target":"An 
operating system manages hardware resources and provides services for computer programs."},
+{"input":"What is virtualization?","target":"Virtualization allows multiple virtual machines to run on a single physical computer."},
+{"input":"Describe the purpose of APIs.","target":"APIs allow software systems to communicate and exchange data with each other."},
+{"input":"What is an embedded system?","target":"An embedded system is a specialized computer designed to perform dedicated functions within larger devices."},
+{"input":"Explain sensor technology.","target":"Sensors detect physical signals like temperature or light and convert them into measurable data."},
+{"input":"What is satellite imaging?","target":"Satellite imaging captures images of Earth using sensors mounted on orbiting satellites."},
+{"input":"Describe machine perception.","target":"Machine perception enables computers to interpret sensory data such as images, sound, or motion."},
+{"input":"What is distributed computing?","target":"Distributed computing uses multiple computers working together to solve large computational problems."},
+{"input":"Explain artificial neural networks.","target":"Artificial neural networks are computational systems inspired by biological neural networks."}
+]
\ No newline at end of file
diff --git a/examples/rkllm/inference.cpp b/examples/rkllm/inference.cpp
new file mode 100644
index 0000000..1b22ae7
--- /dev/null
+++ b/examples/rkllm/inference.cpp
@@ -0,0 +1,458 @@
+#include <iostream>
+#include "rkllm.h"
+#include <chrono>
+#include <csignal>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <iomanip>
+#include <stdexcept>
+#include <string>
+#include <utility>
+
+namespace {
+
+struct Options {
+    std::string model_path;
+    std::string target_platform = "rk3588";
+    std::string prompt;
+    std::string role = "user";
+
+    int max_new_tokens = 512;
+    int max_context_len = 4096;
+    int top_k = 1;
+    float top_p = 0.95f;
+    float temperature = 0.8f;
+    float repeat_penalty = 1.1f;
+    float frequency_penalty = 0.0f;
+    float presence_penalty = 0.0f;
+    int base_domain_id = 0;
+    int
enabled_cpus_num = 4; + + bool stream = false; + bool print_perf = false; + bool keep_history = false; + bool enable_thinking = false; + bool embed_flash = true; + bool keep_special_tokens = false; + + std::string lora_model_path; + std::string lora_name = "default_lora"; + float lora_scale = 1.0f; + + std::string prompt_cache_load_path; + std::string prompt_cache_save_path; + + std::string system_prompt; + std::string chat_template_prefix; + std::string chat_template_postfix; +}; + +struct CallbackState { + bool stream = false; + bool has_error = false; + std::string text; + RKLLMPerfStat perf{}; +}; + +class RKLLMApp { +public: + explicit RKLLMApp(Options options) : options_(std::move(options)) {} + + ~RKLLMApp() { + if (handle_ != nullptr) { + rkllm_destroy(handle_); + handle_ = nullptr; + } + } + + void Init() { + RKLLMParam param = rkllm_createDefaultParam(); + param.model_path = options_.model_path.c_str(); + param.max_new_tokens = options_.max_new_tokens; + param.max_context_len = options_.max_context_len; + param.top_k = options_.top_k; + param.top_p = options_.top_p; + param.temperature = options_.temperature; + param.repeat_penalty = options_.repeat_penalty; + param.frequency_penalty = options_.frequency_penalty; + param.presence_penalty = options_.presence_penalty; + param.skip_special_token = !options_.keep_special_tokens; + param.is_async = false; + param.extend_param.base_domain_id = options_.base_domain_id; + param.extend_param.embed_flash = options_.embed_flash ? 
1 : 0;
+        param.extend_param.n_batch = 1;
+        param.extend_param.use_cross_attn = 0;
+        param.extend_param.enabled_cpus_num = static_cast<int8_t>(options_.enabled_cpus_num);
+        param.extend_param.enabled_cpus_mask = BuildCpuMask(
+            options_.target_platform,
+            options_.enabled_cpus_num
+        );
+
+        const int ret = rkllm_init(&handle_, &param, Callback);
+        if (ret != 0) {
+            throw std::runtime_error("rkllm_init failed with code " + std::to_string(ret));
+        }
+
+        if (!options_.lora_model_path.empty()) {
+            RKLLMLoraAdapter adapter{};
+            adapter.lora_adapter_path = options_.lora_model_path.c_str();
+            adapter.lora_adapter_name = options_.lora_name.c_str();
+            adapter.scale = options_.lora_scale;
+            const int lora_ret = rkllm_load_lora(handle_, &adapter);
+            if (lora_ret != 0) {
+                throw std::runtime_error("rkllm_load_lora failed with code " + std::to_string(lora_ret));
+            }
+            lora_param_.lora_adapter_name = options_.lora_name.c_str();
+            lora_loaded_ = true;
+        }
+
+        if (!options_.prompt_cache_load_path.empty()) {
+            const int cache_ret = rkllm_load_prompt_cache(handle_, options_.prompt_cache_load_path.c_str());
+            if (cache_ret != 0) {
+                throw std::runtime_error("rkllm_load_prompt_cache failed with code " + std::to_string(cache_ret));
+            }
+        }
+
+        if (!options_.system_prompt.empty() ||
+            !options_.chat_template_prefix.empty() ||
+            !options_.chat_template_postfix.empty()) {
+            const int tpl_ret = rkllm_set_chat_template(
+                handle_,
+                options_.system_prompt.c_str(),
+                options_.chat_template_prefix.c_str(),
+                options_.chat_template_postfix.c_str()
+            );
+            if (tpl_ret != 0) {
+                throw std::runtime_error("rkllm_set_chat_template failed with code " + std::to_string(tpl_ret));
+            }
+        }
+    }
+
+    std::string Generate(const std::string& prompt, const std::string& role, bool enable_thinking) {
+        callback_state_.stream = options_.stream;
+        callback_state_.has_error = false;
+        callback_state_.text.clear();
+        last_generation_wall_time_s_ = 0.0;
+        std::memset(&callback_state_.perf, 0, sizeof(callback_state_.perf));
+
+
RKLLMInput input{};
+        input.role = role.c_str();
+        input.enable_thinking = enable_thinking;
+        input.input_type = RKLLM_INPUT_PROMPT;
+        input.prompt_input = prompt.c_str();
+
+        RKLLMInferParam infer_param{};
+        infer_param.mode = RKLLM_INFER_GENERATE;
+        infer_param.keep_history = options_.keep_history ? 1 : 0;
+        infer_param.lora_params = lora_loaded_ ? &lora_param_ : nullptr;
+
+        RKLLMPromptCacheParam prompt_cache_param{};
+        if (!options_.prompt_cache_save_path.empty()) {
+            prompt_cache_param.save_prompt_cache = 1;
+            prompt_cache_param.prompt_cache_path = options_.prompt_cache_save_path.c_str();
+            infer_param.prompt_cache_params = &prompt_cache_param;
+        } else {
+            infer_param.prompt_cache_params = nullptr;
+        }
+
+        const auto started_at = std::chrono::steady_clock::now();
+        const int ret = rkllm_run(handle_, &input, &infer_param, &callback_state_);
+        const auto ended_at = std::chrono::steady_clock::now();
+        last_generation_wall_time_s_ = std::chrono::duration<double>(ended_at - started_at).count();
+        if (ret != 0) {
+            throw std::runtime_error("rkllm_run failed with code " + std::to_string(ret));
+        }
+        if (callback_state_.has_error) {
+            throw std::runtime_error("runtime callback reported RKLLM_RUN_ERROR");
+        }
+        return callback_state_.text;
+    }
+
+    void ClearKvCache(bool keep_system_prompt) {
+        const int ret = rkllm_clear_kv_cache(handle_, keep_system_prompt ? 1 : 0, nullptr, nullptr);
+        if (ret != 0) {
+            throw std::runtime_error("rkllm_clear_kv_cache failed with code " + std::to_string(ret));
+        }
+    }
+
+    void Abort() {
+        if (handle_ != nullptr) {
+            rkllm_abort(handle_);
+        }
+    }
+
+    void PrintPerf() const {
+        const auto& perf = callback_state_.perf;
+        const double decode_tok_per_s =
+            perf.generate_time_ms > 0.0f ?
+                static_cast<double>(perf.generate_tokens) / (static_cast<double>(perf.generate_time_ms) / 1000.0) :
+                0.0;
+
+        const double actual_tok_per_s =
+            last_generation_wall_time_s_ > 0.0 ?
+                static_cast<double>(perf.generate_tokens) / last_generation_wall_time_s_ :
+                0.0;
+
+        std::cout << std::fixed << std::setprecision(2)
+                  << "[perf] prefill=" << perf.prefill_time_ms << "ms/" << perf.prefill_tokens
+                  << " tok | decode=" << perf.generate_time_ms << "ms/" << perf.generate_tokens
+                  << " tok (" << decode_tok_per_s << " tok/s)"
+                  << " | actual=" << actual_tok_per_s << " tok/s"
+                  << " | mem=" << perf.memory_usage_mb << " MB\n";
+        std::cout.unsetf(std::ios::floatfield);
+        std::cout << std::setprecision(6);
+    }
+
+private:
+    static int MaxCpusForPlatform(const std::string& platform) {
+        if (platform == "rk3588" || platform == "rk3576") {
+            return 8;
+        }
+        if (platform == "rk3562" || platform == "rv1126b") {
+            return 4;
+        }
+        return 8;
+    }
+
+    static uint32_t BuildCpuMask(const std::string& platform, int enabled_cpus_num) {
+        int start_cpu = 0;
+
+        if (platform == "rk3588" || platform == "rk3576") {
+            start_cpu = 4;  // use BIG cores only
+        }
+
+        uint32_t mask = 0;
+        for (int cpu = start_cpu; cpu < start_cpu + enabled_cpus_num; ++cpu) {
+            mask |= (1u << cpu);
+        }
+        return mask;
+    }
+
+    static int Callback(RKLLMResult* result, void* userdata, LLMCallState state) {
+        auto* cb = static_cast<CallbackState*>(userdata);
+        if (cb == nullptr) {
+            return 0;
+        }
+
+        if (state == RKLLM_RUN_ERROR) {
+            cb->has_error = true;
+            std::cerr << "\n[error] RKLLM runtime callback returned RKLLM_RUN_ERROR\n";
+            return 0;
+        }
+
+        if (state == RKLLM_RUN_FINISH) {
+            if (result != nullptr) {
+                cb->perf = result->perf;
+            }
+            if (cb->stream) {
+                std::cout << std::endl;
+            }
+            return 0;
+        }
+
+        if (state == RKLLM_RUN_NORMAL || state == RKLLM_RUN_WAITING) {
+            if (result != nullptr && result->text != nullptr) {
+                cb->text += result->text;
+                if (cb->stream) {
+                    std::cout << result->text << std::flush;
+                }
+            }
+        }
+        return 0;
+    }
+
+    Options options_;
+    LLMHandle handle_ = nullptr;
+    CallbackState callback_state_{};
+    RKLLMLoraParam lora_param_{};
+    bool lora_loaded_ = false;
+    double last_generation_wall_time_s_ =
0.0; +}; + +RKLLMApp* g_app = nullptr; + +void SignalHandler(int signal_num) { + if (g_app != nullptr) { + g_app->Abort(); + } + std::cerr << "\nInterrupted by signal " << signal_num << ". Exiting.\n"; + std::_Exit(130); +} + +void PrintUsage(const char* bin) { + std::cout + << "Usage:\n" + << " " << bin << " --model [--prompt \"Hello\"] [options]\n\n" + << "Key options:\n" + << " --model Required rkllm model path\n" + << " --target-platform rk3588|rk3576|rk3562|rv1126b (default rk3588)\n" + << " --prompt Single-shot prompt; if omitted, interactive mode\n" + << " --stream Stream token output in callback\n" + << " --print-perf Print perf stats after each run\n" + << " --keep-history Keep conversation history (multi-turn)\n" + << " --enable-thinking Set enable_thinking=true in RKLLMInput\n" + << " --role Input role (default user)\n" + << " --max-new-tokens Default 512\n" + << " --max-context-len Default 4096\n" + << " --enabled-cpus-num Default 4; rk3588/rk3576: 1-8, rk3562/rv1126b: 1-4\n" + << " --lora-model Optional LoRA model path\n" + << " --lora-name Optional LoRA adapter name (default default_lora)\n" + << " --lora-scale Optional LoRA scale (default 1.0)\n" + << " --prompt-cache-load Optional prompt cache preload\n" + << " --prompt-cache-save Optional prompt cache save path\n" + << " --system-prompt Optional system prompt for custom template\n" + << " --chat-template-prefix Optional custom prompt prefix\n" + << " --chat-template-postfix Optional custom prompt postfix\n" + << " --help Show this help\n\n" + << "Interactive commands:\n" + << " exit Exit program\n" + << " clear Clear KV cache\n"; +} + +bool NextValue(int argc, char** argv, int& i, std::string& out) { + if (i + 1 >= argc) { + return false; + } + out = argv[++i]; + return true; +} + +Options ParseArgs(int argc, char** argv) { + Options opt; + for (int i = 1; i < argc; ++i) { + const std::string arg = argv[i]; + std::string value; + if (arg == "--help" || arg == "-h") { + PrintUsage(argv[0]); + 
std::exit(0); + } else if (arg == "--model" && NextValue(argc, argv, i, value)) { + opt.model_path = value; + } else if (arg == "--target-platform" && NextValue(argc, argv, i, value)) { + opt.target_platform = value; + } else if (arg == "--prompt" && NextValue(argc, argv, i, value)) { + opt.prompt = value; + } else if (arg == "--role" && NextValue(argc, argv, i, value)) { + opt.role = value; + } else if (arg == "--max-new-tokens" && NextValue(argc, argv, i, value)) { + opt.max_new_tokens = std::stoi(value); + } else if (arg == "--max-context-len" && NextValue(argc, argv, i, value)) { + opt.max_context_len = std::stoi(value); + } else if (arg == "--top-k" && NextValue(argc, argv, i, value)) { + opt.top_k = std::stoi(value); + } else if (arg == "--top-p" && NextValue(argc, argv, i, value)) { + opt.top_p = std::stof(value); + } else if (arg == "--temperature" && NextValue(argc, argv, i, value)) { + opt.temperature = std::stof(value); + } else if (arg == "--repeat-penalty" && NextValue(argc, argv, i, value)) { + opt.repeat_penalty = std::stof(value); + } else if (arg == "--frequency-penalty" && NextValue(argc, argv, i, value)) { + opt.frequency_penalty = std::stof(value); + } else if (arg == "--presence-penalty" && NextValue(argc, argv, i, value)) { + opt.presence_penalty = std::stof(value); + } else if (arg == "--base-domain-id" && NextValue(argc, argv, i, value)) { + opt.base_domain_id = std::stoi(value); + } else if (arg == "--enabled-cpus-num" && NextValue(argc, argv, i, value)) { + opt.enabled_cpus_num = std::stoi(value); + } else if (arg == "--lora-model" && NextValue(argc, argv, i, value)) { + opt.lora_model_path = value; + } else if (arg == "--lora-name" && NextValue(argc, argv, i, value)) { + opt.lora_name = value; + } else if (arg == "--lora-scale" && NextValue(argc, argv, i, value)) { + opt.lora_scale = std::stof(value); + } else if (arg == "--prompt-cache-load" && NextValue(argc, argv, i, value)) { + opt.prompt_cache_load_path = value; + } else if (arg == 
"--prompt-cache-save" && NextValue(argc, argv, i, value)) { + opt.prompt_cache_save_path = value; + } else if (arg == "--system-prompt" && NextValue(argc, argv, i, value)) { + opt.system_prompt = value; + } else if (arg == "--chat-template-prefix" && NextValue(argc, argv, i, value)) { + opt.chat_template_prefix = value; + } else if (arg == "--chat-template-postfix" && NextValue(argc, argv, i, value)) { + opt.chat_template_postfix = value; + } else if (arg == "--stream") { + opt.stream = true; + } else if (arg == "--print-perf") { + opt.print_perf = true; + } else if (arg == "--keep-history") { + opt.keep_history = true; + } else if (arg == "--enable-thinking") { + opt.enable_thinking = true; + } else if (arg == "--no-embed-flash") { + opt.embed_flash = false; + } else if (arg == "--keep-special-tokens") { + opt.keep_special_tokens = true; + } else { + throw std::runtime_error("Unknown or incomplete argument: " + arg); + } + } + + if (opt.model_path.empty()) { + throw std::runtime_error("missing required argument --model"); + } + if (opt.role != "user" && opt.role != "tool") { + throw std::runtime_error("--role must be either 'user' or 'tool'"); + } + return opt; +} + +} // namespace + +int main(int argc, char** argv) { + try { + const Options options = ParseArgs(argc, argv); + RKLLMApp app(options); + g_app = &app; + std::signal(SIGINT, SignalHandler); + + app.Init(); + + if (!options.prompt.empty()) { + if (!options.stream) { + std::cout << "assistant: "; + } + const std::string answer = app.Generate(options.prompt, options.role, options.enable_thinking); + if (!options.stream) { + std::cout << answer << '\n'; + } + if (options.print_perf) { + app.PrintPerf(); + } + return 0; + } + + std::cout << "RKLLM interactive mode. 
Commands: exit, clear\n"; + while (true) { + std::cout << "\nuser: "; + std::string line; + if (!std::getline(std::cin, line)) { + break; + } + + if (line.empty()) { + continue; + } + if (line == "exit") { + break; + } + if (line == "clear") { + app.ClearKvCache(!options.system_prompt.empty()); + std::cout << "KV cache cleared.\n"; + continue; + } + + if (!options.stream) { + std::cout << "assistant: "; + } + const std::string answer = app.Generate(line, options.role, options.enable_thinking); + if (!options.stream) { + std::cout << answer << '\n'; + } + if (options.print_perf) { + app.PrintPerf(); + } + } + return 0; + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << "\n"; + return 1; + } +} diff --git a/examples/rkllm/inference.py b/examples/rkllm/inference.py new file mode 100644 index 0000000..aeb3cb8 --- /dev/null +++ b/examples/rkllm/inference.py @@ -0,0 +1,656 @@ +import argparse +import ctypes +import signal +import sys +import time +from pathlib import Path +from typing import List, Optional + + +LLMHandle = ctypes.c_void_p + +RKLLM_RUN_NORMAL = 0 +RKLLM_RUN_WAITING = 1 +RKLLM_RUN_FINISH = 2 +RKLLM_RUN_ERROR = 3 + +RKLLM_INPUT_PROMPT = 0 +RKLLM_INFER_GENERATE = 0 + +TARGET_PLATFORMS = ("rk3576", "rk3588", "rk3562", "rv1126b") + + +class RKLLMExtendParam(ctypes.Structure): + _fields_ = [ + ("base_domain_id", ctypes.c_int32), + ("embed_flash", ctypes.c_int8), + ("enabled_cpus_num", ctypes.c_int8), + ("enabled_cpus_mask", ctypes.c_uint32), + ("n_batch", ctypes.c_uint8), + ("use_cross_attn", ctypes.c_int8), + ("reserved", ctypes.c_uint8 * 104), + ] + + +class RKLLMParam(ctypes.Structure): + _fields_ = [ + ("model_path", ctypes.c_char_p), + ("max_context_len", ctypes.c_int32), + ("max_new_tokens", ctypes.c_int32), + ("top_k", ctypes.c_int32), + ("n_keep", ctypes.c_int32), + ("top_p", ctypes.c_float), + ("temperature", ctypes.c_float), + ("repeat_penalty", ctypes.c_float), + ("frequency_penalty", ctypes.c_float), + ("presence_penalty", 
ctypes.c_float), + ("mirostat", ctypes.c_int32), + ("mirostat_tau", ctypes.c_float), + ("mirostat_eta", ctypes.c_float), + ("skip_special_token", ctypes.c_bool), + ("is_async", ctypes.c_bool), + ("img_start", ctypes.c_char_p), + ("img_end", ctypes.c_char_p), + ("img_content", ctypes.c_char_p), + ("extend_param", RKLLMExtendParam), + ] + + +class RKLLMLoraAdapter(ctypes.Structure): + _fields_ = [ + ("lora_adapter_path", ctypes.c_char_p), + ("lora_adapter_name", ctypes.c_char_p), + ("scale", ctypes.c_float), + ] + + +class RKLLMEmbedInput(ctypes.Structure): + _fields_ = [ + ("embed", ctypes.POINTER(ctypes.c_float)), + ("n_tokens", ctypes.c_size_t), + ] + + +class RKLLMTokenInput(ctypes.Structure): + _fields_ = [ + ("input_ids", ctypes.POINTER(ctypes.c_int32)), + ("n_tokens", ctypes.c_size_t), + ] + + +class RKLLMMultiModalInput(ctypes.Structure): + _fields_ = [ + ("prompt", ctypes.c_char_p), + ("image_embed", ctypes.POINTER(ctypes.c_float)), + ("n_image_tokens", ctypes.c_size_t), + ("n_image", ctypes.c_size_t), + ("image_width", ctypes.c_size_t), + ("image_height", ctypes.c_size_t), + ] + + +class RKLLMInputUnion(ctypes.Union): + _fields_ = [ + ("prompt_input", ctypes.c_char_p), + ("embed_input", RKLLMEmbedInput), + ("token_input", RKLLMTokenInput), + ("multimodal_input", RKLLMMultiModalInput), + ] + + +class RKLLMInput(ctypes.Structure): + _fields_ = [ + ("role", ctypes.c_char_p), + ("enable_thinking", ctypes.c_bool), + ("input_type", ctypes.c_int), + ("input_data", RKLLMInputUnion), + ] + + +class RKLLMLoraParam(ctypes.Structure): + _fields_ = [("lora_adapter_name", ctypes.c_char_p)] + + +class RKLLMPromptCacheParam(ctypes.Structure): + _fields_ = [ + ("save_prompt_cache", ctypes.c_int), + ("prompt_cache_path", ctypes.c_char_p), + ] + + +class RKLLMInferParam(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("lora_params", ctypes.POINTER(RKLLMLoraParam)), + ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)), + ("keep_history", 
ctypes.c_int), + ] + + +class RKLLMResultLastHiddenLayer(ctypes.Structure): + _fields_ = [ + ("hidden_states", ctypes.POINTER(ctypes.c_float)), + ("embd_size", ctypes.c_int), + ("num_tokens", ctypes.c_int), + ] + + +class RKLLMResultLogits(ctypes.Structure): + _fields_ = [ + ("logits", ctypes.POINTER(ctypes.c_float)), + ("vocab_size", ctypes.c_int), + ("num_tokens", ctypes.c_int), + ] + + +class RKLLMPerfStat(ctypes.Structure): + _fields_ = [ + ("prefill_time_ms", ctypes.c_float), + ("prefill_tokens", ctypes.c_int), + ("generate_time_ms", ctypes.c_float), + ("generate_tokens", ctypes.c_int), + ("memory_usage_mb", ctypes.c_float), + ] + + +class RKLLMResult(ctypes.Structure): + _fields_ = [ + ("text", ctypes.c_char_p), + ("token_id", ctypes.c_int32), + ("last_hidden_layer", RKLLMResultLastHiddenLayer), + ("logits", RKLLMResultLogits), + ("perf", RKLLMPerfStat), + ] + + +LLMResultCallbackType = ctypes.CFUNCTYPE( + ctypes.c_int, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int +) + + +def cpu_mask_for_platform(platform: str, enabled_cpus_num: int) -> int: + if platform in {"rk3576", "rk3588"}: + max_cpus = 8 + elif platform in {"rk3562", "rv1126b"}: + max_cpus = 4 + else: + max_cpus = 8 + + if enabled_cpus_num <= 0 or enabled_cpus_num > max_cpus: + raise ValueError( + f"enabled-cpus-num must be in the range [1, {max_cpus}] for platform {platform}" + ) + + mask = 0 + for cpu in range(enabled_cpus_num): + mask |= 1 << cpu + return mask + + +def load_runtime_library(explicit_path: Optional[str]) -> ctypes.CDLL: + candidates = [] + if explicit_path: + candidates.append(Path(explicit_path)) + candidates.extend( + [ + Path("librkllmrt.so"), + Path("lib/librkllmrt.so"), + Path("rknn-llm/rkllm-runtime/Linux/librkllm_api/aarch64/librkllmrt.so"), + Path("rknn-llm/rkllm-runtime/Linux/librkllm_api/lib/librkllmrt.so"), + ] + ) + + for candidate in candidates: + if candidate.exists(): + return ctypes.CDLL(str(candidate.resolve())) + if explicit_path: + return 
+        ctypes.CDLL(explicit_path)
+    return ctypes.CDLL("librkllmrt.so")
+
+
+class RKLLMRunner:
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+        self._text_chunks: List[str] = []
+        self._last_perf: Optional[RKLLMPerfStat] = None
+        self._last_generation_wall_time_s: Optional[float] = None
+        self._lora_param_ref = None
+        self._prompt_cache_param_ref = None
+        self.handle = LLMHandle()
+        self.lib = load_runtime_library(args.runtime_lib)
+        self._bind_functions()
+        self._create_callback()
+        self._init_model()
+        if args.lora_model:
+            self._load_lora(args.lora_model, args.lora_adapter_name)
+        if args.prompt_cache_load:
+            self._load_prompt_cache(args.prompt_cache_load)
+        if args.chat_template_prefix or args.chat_template_postfix or args.system_prompt:
+            self._set_chat_template(
+                args.system_prompt or "",
+                args.chat_template_prefix or "",
+                args.chat_template_postfix or "",
+            )
+        self.infer_param = self._build_infer_param()
+
+    def _bind_functions(self) -> None:
+        self.lib.rkllm_createDefaultParam.argtypes = []
+        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
+
+        self.lib.rkllm_init.argtypes = [
+            ctypes.POINTER(LLMHandle),
+            ctypes.POINTER(RKLLMParam),
+            LLMResultCallbackType,
+        ]
+        self.lib.rkllm_init.restype = ctypes.c_int
+
+        self.lib.rkllm_run.argtypes = [
+            LLMHandle,
+            ctypes.POINTER(RKLLMInput),
+            ctypes.POINTER(RKLLMInferParam),
+            ctypes.c_void_p,
+        ]
+        self.lib.rkllm_run.restype = ctypes.c_int
+
+        self.lib.rkllm_destroy.argtypes = [LLMHandle]
+        self.lib.rkllm_destroy.restype = ctypes.c_int
+
+        self.lib.rkllm_clear_kv_cache.argtypes = [
+            LLMHandle,
+            ctypes.c_int,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.POINTER(ctypes.c_int),
+        ]
+        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
+
+        self.lib.rkllm_abort.argtypes = [LLMHandle]
+        self.lib.rkllm_abort.restype = ctypes.c_int
+
+        self.lib.rkllm_set_chat_template.argtypes = [
+            LLMHandle,
+            ctypes.c_char_p,
+            ctypes.c_char_p,
+            ctypes.c_char_p,
+        ]
+        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
+
+        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
+        self.lib.rkllm_load_lora.restype = ctypes.c_int
+
+        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
+        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
+
+    def _create_callback(self) -> None:
+        def _callback(result_ptr, _userdata, state):
+            if state == RKLLM_RUN_FINISH:
+                result = result_ptr.contents
+                self._last_perf = result.perf
+                if self.args.stream:
+                    sys.stdout.write("\n")
+                    sys.stdout.flush()
+                return 0
+
+            if state == RKLLM_RUN_ERROR:
+                print("RKLLM runtime reported an inference error.", file=sys.stderr)
+                return 0
+
+            if state not in (RKLLM_RUN_NORMAL, RKLLM_RUN_WAITING):
+                return 0
+
+            result = result_ptr.contents
+            if not result.text:
+                return 0
+            chunk = ctypes.string_at(result.text).decode("utf-8", errors="ignore")
+            if not chunk:
+                return 0
+            self._text_chunks.append(chunk)
+            if self.args.stream:
+                sys.stdout.write(chunk)
+                sys.stdout.flush()
+            return 0
+
+        self._callback = LLMResultCallbackType(_callback)
+
+    def _init_model(self) -> None:
+        param = self.lib.rkllm_createDefaultParam()
+        param.model_path = str(Path(self.args.model).resolve()).encode("utf-8")
+        param.max_new_tokens = self.args.max_new_tokens
+        param.max_context_len = self.args.max_context_len
+        param.top_k = self.args.top_k
+        param.top_p = self.args.top_p
+        param.temperature = self.args.temperature
+        param.repeat_penalty = self.args.repeat_penalty
+        param.frequency_penalty = self.args.frequency_penalty
+        param.presence_penalty = self.args.presence_penalty
+        param.skip_special_token = not self.args.keep_special_tokens
+        param.is_async = False
+        param.extend_param.base_domain_id = self.args.base_domain_id
+        param.extend_param.embed_flash = 1 if self.args.embed_flash else 0
+        param.extend_param.n_batch = 1
+        param.extend_param.use_cross_attn = 0
+        param.extend_param.enabled_cpus_num = self.args.enabled_cpus_num
+        param.extend_param.enabled_cpus_mask = cpu_mask_for_platform(
+            self.args.target_platform, self.args.enabled_cpus_num
+        )
+
+        ret = self.lib.rkllm_init(ctypes.byref(self.handle), ctypes.byref(param), self._callback)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_init failed with code {ret}")
+
+    def _load_lora(self, lora_path: str, lora_name: str) -> None:
+        adapter = RKLLMLoraAdapter()
+        adapter.lora_adapter_path = str(Path(lora_path).resolve()).encode("utf-8")
+        adapter.lora_adapter_name = lora_name.encode("utf-8")
+        adapter.scale = self.args.lora_scale
+        ret = self.lib.rkllm_load_lora(self.handle, ctypes.byref(adapter))
+        if ret != 0:
+            raise RuntimeError(f"rkllm_load_lora failed with code {ret}")
+        lora_param = RKLLMLoraParam()
+        lora_param.lora_adapter_name = lora_name.encode("utf-8")
+        self._lora_param_ref = lora_param
+
+    def _load_prompt_cache(self, prompt_cache_path: str) -> None:
+        ret = self.lib.rkllm_load_prompt_cache(
+            self.handle, str(Path(prompt_cache_path).resolve()).encode("utf-8")
+        )
+        if ret != 0:
+            raise RuntimeError(f"rkllm_load_prompt_cache failed with code {ret}")
+
+    def _set_chat_template(self, system_prompt: str, prefix: str, postfix: str) -> None:
+        ret = self.lib.rkllm_set_chat_template(
+            self.handle,
+            system_prompt.encode("utf-8"),
+            prefix.encode("utf-8"),
+            postfix.encode("utf-8"),
+        )
+        if ret != 0:
+            raise RuntimeError(f"rkllm_set_chat_template failed with code {ret}")
+
+    def _build_infer_param(self) -> RKLLMInferParam:
+        infer_param = RKLLMInferParam()
+        ctypes.memset(ctypes.byref(infer_param), 0, ctypes.sizeof(RKLLMInferParam))
+        infer_param.mode = RKLLM_INFER_GENERATE
+        infer_param.keep_history = 1 if self.args.keep_history else 0
+        if self._lora_param_ref is not None:
+            infer_param.lora_params = ctypes.pointer(self._lora_param_ref)
+        else:
+            infer_param.lora_params = None
+
+        if self.args.prompt_cache_save:
+            cache_param = RKLLMPromptCacheParam()
+            cache_param.save_prompt_cache = 1
+            cache_param.prompt_cache_path = str(Path(self.args.prompt_cache_save).resolve()).encode(
+                "utf-8"
+            )
+            self._prompt_cache_param_ref = cache_param
+            infer_param.prompt_cache_params = ctypes.pointer(self._prompt_cache_param_ref)
+        else:
+            infer_param.prompt_cache_params = None
+        return infer_param
+
+    def generate(self, prompt: str, role: str = "user", enable_thinking: bool = False) -> str:
+        self._text_chunks = []
+        self._last_perf = None
+        self._last_generation_wall_time_s = None
+
+        rk_input = RKLLMInput()
+        ctypes.memset(ctypes.byref(rk_input), 0, ctypes.sizeof(RKLLMInput))
+        role_bytes = role.encode("utf-8")
+        prompt_bytes = prompt.encode("utf-8")
+        rk_input.role = role_bytes
+        rk_input.enable_thinking = bool(enable_thinking)
+        rk_input.input_type = RKLLM_INPUT_PROMPT
+        rk_input.input_data.prompt_input = prompt_bytes
+
+        started_at = time.perf_counter()
+        ret = self.lib.rkllm_run(
+            self.handle,
+            ctypes.byref(rk_input),
+            ctypes.byref(self.infer_param),
+            None,
+        )
+        self._last_generation_wall_time_s = time.perf_counter() - started_at
+        if ret != 0:
+            raise RuntimeError(f"rkllm_run failed with code {ret}")
+        return "".join(self._text_chunks)
+
+    def clear_kv_cache(self, keep_system_prompt: bool = True) -> None:
+        ret = self.lib.rkllm_clear_kv_cache(
+            self.handle, 1 if keep_system_prompt else 0, None, None
+        )
+        if ret != 0:
+            raise RuntimeError(f"rkllm_clear_kv_cache failed with code {ret}")
+
+    def abort(self) -> None:
+        self.lib.rkllm_abort(self.handle)
+
+    def close(self) -> None:
+        if self.handle:
+            self.lib.rkllm_destroy(self.handle)
+            self.handle = LLMHandle()
+
+    def print_perf(self) -> None:
+        if self._last_perf is None:
+            return
+        perf = self._last_perf
+        decode_tok_per_s = 0.0
+        if perf.generate_time_ms > 0:
+            decode_tok_per_s = perf.generate_tokens / (perf.generate_time_ms / 1000.0)
+        actual_tok_per_s = 0.0
+        if self._last_generation_wall_time_s and self._last_generation_wall_time_s > 0:
+            actual_tok_per_s = perf.generate_tokens / self._last_generation_wall_time_s
+        print(
+            "\n[perf] prefill={:.2f}ms/{} tok | decode={:.2f}ms/{} tok ({:.2f} tok/s) | actual={:.2f} tok/s | mem={:.2f} MB".format(
+                perf.prefill_time_ms,
+                perf.prefill_tokens,
+                perf.generate_time_ms,
+                perf.generate_tokens,
+                decode_tok_per_s,
+                actual_tok_per_s,
+                perf.memory_usage_mb,
+            )
+        )
+
+
+def positive_int(text: str) -> int:
+    value = int(text)
+    if value <= 0:
+        raise argparse.ArgumentTypeError("value must be > 0")
+    return value
+
+
+def positive_float(text: str) -> float:
+    value = float(text)
+    if value <= 0:
+        raise argparse.ArgumentTypeError("value must be > 0")
+    return value
+
+
+def build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="RKLLM runtime inference script (SDK 1.2.3 style init/run/destroy flow)"
+    )
+    parser.add_argument(
+        "-m",
+        "--model",
+        required=True,
+        help="Path to the .rkllm model file",
+    )
+    parser.add_argument(
+        "--runtime-lib",
+        help="Path to librkllmrt.so (optional if discoverable via default paths or LD_LIBRARY_PATH)",
+    )
+    parser.add_argument(
+        "--target-platform",
+        type=str.lower,
+        choices=TARGET_PLATFORMS,
+        default="rk3588",
+        help="Target Rockchip platform",
+    )
+    parser.add_argument("--max-new-tokens", type=positive_int, default=512)
+    parser.add_argument("--max-context-len", type=positive_int, default=4096)
+    parser.add_argument("--top-k", type=int, default=1)
+    parser.add_argument("--top-p", type=float, default=0.95)
+    parser.add_argument("--temperature", type=float, default=0.8)
+    parser.add_argument("--repeat-penalty", type=positive_float, default=1.1)
+    parser.add_argument("--frequency-penalty", type=float, default=0.0)
+    parser.add_argument("--presence-penalty", type=float, default=0.0)
+    parser.add_argument("--base-domain-id", type=int, default=0)
+    parser.add_argument(
+        "--enabled-cpus-num",
+        type=int,
+        default=4,
+        help="Number of CPUs: rk3588/rk3576: 1-8, rk3562/rv1126b: 1-4",
+    )
+    parser.add_argument(
+        "--embed-flash",
+        action="store_true",
+        default=True,
+        help="Enable embed_flash in RKLLMExtendParam",
+    )
+    parser.add_argument(
+        "--no-embed-flash",
+        action="store_false",
+        dest="embed_flash",
+        help="Disable embed_flash in RKLLMExtendParam",
+    )
+    parser.add_argument(
+        "--keep-special-tokens",
+        action="store_true",
+        help="Do not skip special tokens in output",
+    )
+    parser.add_argument(
+        "--keep-history",
+        action="store_true",
+        help="Enable multi-turn history retention (keep_history=1)",
+    )
+    parser.add_argument(
+        "--stream",
+        action="store_true",
+        help="Stream generated tokens to stdout during inference",
+    )
+    parser.add_argument(
+        "--print-perf",
+        action="store_true",
+        help="Print RKLLM perf stats at the end of each run",
+    )
+    parser.add_argument(
+        "--prompt",
+        help="Single-shot prompt. If not set, run interactive mode.",
+    )
+    parser.add_argument(
+        "--enable-thinking",
+        action="store_true",
+        help="Set enable_thinking=true in RKLLMInput (Qwen3 thinking mode)",
+    )
+    parser.add_argument(
+        "--role",
+        default="user",
+        choices=("user", "tool"),
+        help="Role field for RKLLMInput",
+    )
+    parser.add_argument("--lora-model", help="Path to LoRA model file")
+    parser.add_argument(
+        "--lora-adapter-name",
+        default="default_lora",
+        help="Name used to register and select LoRA adapter",
+    )
+    parser.add_argument("--lora-scale", type=float, default=1.0)
+    parser.add_argument("--prompt-cache-load", help="Path of prompt cache file to preload")
+    parser.add_argument("--prompt-cache-save", help="Path to save prompt cache during inference")
+    parser.add_argument("--system-prompt", default="")
+    parser.add_argument("--chat-template-prefix", default="")
+    parser.add_argument("--chat-template-postfix", default="")
+    return parser
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    model_path = Path(args.model)
+    if not model_path.exists():
+        raise SystemExit(f"Model file not found: {model_path}")
+    if args.lora_model and not Path(args.lora_model).exists():
+        raise SystemExit(f"LoRA file not found: {args.lora_model}")
+    if args.prompt_cache_load and not Path(args.prompt_cache_load).exists():
+        raise SystemExit(f"Prompt cache file not found: {args.prompt_cache_load}")
+    if args.target_platform in {"rk3576", "rk3588"}:
+        max_cpus = 8
+    elif args.target_platform in {"rk3562", "rv1126b"}:
+        max_cpus = 4
+    else:
+        max_cpus = 8
+
+    if args.enabled_cpus_num < 1 or args.enabled_cpus_num > max_cpus:
+        raise SystemExit(
+            f"--enabled-cpus-num must be in the range [1, {max_cpus}] for platform {args.target_platform}"
+        )
+
+
+def run_interactive(runner: RKLLMRunner, args: argparse.Namespace) -> None:
+    print("RKLLM interactive mode. Commands: 'exit', 'clear'")
+    while True:
+        try:
+            user_input = input("\nuser: ").strip()
+        except EOFError:
+            print()
+            break
+        if not user_input:
+            continue
+        if user_input.lower() == "exit":
+            break
+        if user_input.lower() == "clear":
+            runner.clear_kv_cache(keep_system_prompt=bool(args.system_prompt))
+            print("KV cache cleared.")
+            continue
+
+        if not args.stream:
+            print("assistant: ", end="", flush=True)
+        answer = runner.generate(
+            prompt=user_input,
+            role=args.role,
+            enable_thinking=args.enable_thinking,
+        )
+        if not args.stream:
+            print(answer)
+        if args.print_perf:
+            runner.print_perf()
+
+
+def main() -> None:
+    parser = build_arg_parser()
+    args = parser.parse_args()
+    validate_args(args)
+
+    runner = None
+
+    def _signal_handler(_sig, _frame):
+        nonlocal runner
+        if runner is not None:
+            runner.abort()
+            runner.close()
+        raise SystemExit(130)
+
+    signal.signal(signal.SIGINT, _signal_handler)
+
+    try:
+        runner = RKLLMRunner(args)
+        if args.prompt:
+            if not args.stream:
+                print("assistant: ", end="", flush=True)
+            answer = runner.generate(
+                prompt=args.prompt,
+                role=args.role,
+                enable_thinking=args.enable_thinking,
+            )
+            if not args.stream:
+                print(answer)
+            if args.print_perf:
+                runner.print_perf()
+        else:
+            run_interactive(runner, args)
+    finally:
+        if runner is not None:
+            runner.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/rkllm/librkllmrt.so b/examples/rkllm/librkllmrt.so
new file mode 100755
index 0000000..0084abc
Binary files /dev/null and b/examples/rkllm/librkllmrt.so differ
diff --git a/examples/rkllm/rkllm.h b/examples/rkllm/rkllm.h
new file mode 100644
index 0000000..3678287
--- /dev/null
+++ b/examples/rkllm/rkllm.h
@@ -0,0 +1,409 @@
+#ifndef _RKLLM_H_
+#define _RKLLM_H_
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define CPU0 (1 << 0) // 0x01
+#define CPU1 (1 << 1) // 0x02
+#define CPU2 (1 << 2) // 0x04
+#define CPU3 (1 << 3) // 0x08
+#define CPU4 (1 << 4) // 0x10
+#define CPU5 (1 << 5) // 0x20
+#define CPU6 (1 << 6) // 0x40
+#define CPU7 (1 << 7) // 0x80
+
+/**
+ * @typedef LLMHandle
+ * @brief A handle used to manage and interact with the large language model.
+ */
+typedef void* LLMHandle;
+
+/**
+ * @enum LLMCallState
+ * @brief Describes the possible states of an LLM call.
+ */
+typedef enum {
+    RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */
+    RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for a complete UTF-8 encoded character. */
+    RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */
+    RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */
+} LLMCallState;
+
+/**
+ * @enum RKLLMInputType
+ * @brief Defines the types of inputs that can be fed into the LLM.
+ */
+typedef enum {
+    RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */
+    RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */
+    RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */
+    RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */
+} RKLLMInputType;
+
+/**
+ * @enum RKLLMInferMode
+ * @brief Specifies the inference modes of the LLM.
+ */
+typedef enum {
+    RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */
+    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */
+    RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */
+} RKLLMInferMode;
+
+/**
+ * @struct RKLLMExtendParam
+ * @brief The extend parameters for configuring an LLM instance.
+ */
+typedef struct {
+    int32_t base_domain_id; /**< base_domain_id */
+    int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */
+    int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */
+    uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */
+    uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */
+    int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). */
+    uint8_t reserved[104]; /**< reserved */
+} RKLLMExtendParam;
+
+/**
+ * @struct RKLLMParam
+ * @brief Defines the parameters for configuring an LLM instance.
+ */
+typedef struct {
+    const char* model_path; /**< Path to the model file. */
+    int32_t max_context_len; /**< Maximum number of tokens in the context window. */
+    int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */
+    int32_t top_k; /**< Top-K sampling parameter for token generation. */
+    int32_t n_keep; /**< Number of KV-cache entries to keep at the beginning when shifting the context window. */
+    float top_p; /**< Top-P (nucleus) sampling parameter. */
+    float temperature; /**< Sampling temperature, affecting the randomness of token selection. */
+    float repeat_penalty; /**< Penalty for repeating tokens in generation. */
+    float frequency_penalty; /**< Penalizes frequent tokens during generation. */
+    float presence_penalty; /**< Penalizes tokens based on their presence in the input. */
+    int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */
+    float mirostat_tau; /**< Tau parameter for Mirostat sampling. */
+    float mirostat_eta; /**< Eta parameter for Mirostat sampling. */
+    bool skip_special_token; /**< Whether to skip special tokens during generation. */
+    bool is_async; /**< Whether to run inference asynchronously. */
+    const char* img_start; /**< Starting position of an image in multimodal input. */
+    const char* img_end; /**< Ending position of an image in multimodal input. */
+    const char* img_content; /**< Pointer to the image content. */
+    RKLLMExtendParam extend_param; /**< Extend parameters. */
+} RKLLMParam;
+
+/**
+ * @struct RKLLMLoraAdapter
+ * @brief Defines parameters for a Lora adapter used in model fine-tuning.
+ */
+typedef struct {
+    const char* lora_adapter_path; /**< Path to the Lora adapter file. */
+    const char* lora_adapter_name; /**< Name of the Lora adapter. */
+    float scale; /**< Scaling factor for applying the Lora adapter. */
+} RKLLMLoraAdapter;
+
+/**
+ * @struct RKLLMEmbedInput
+ * @brief Represents an embedding input to the LLM.
+ */
+typedef struct {
+    float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */
+    size_t n_tokens; /**< Number of tokens represented in the embedding. */
+} RKLLMEmbedInput;
+
+/**
+ * @struct RKLLMTokenInput
+ * @brief Represents token input to the LLM.
+ */
+typedef struct {
+    int32_t* input_ids; /**< Array of token IDs. */
+    size_t n_tokens; /**< Number of tokens in the input. */
+} RKLLMTokenInput;
+
+/**
+ * @struct RKLLMMultiModalInput
+ * @brief Represents multimodal input (e.g., text and image).
+ */
+typedef struct {
+    char* prompt; /**< Text prompt input. */
+    float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */
+    size_t n_image_tokens; /**< Number of tokens per image. */
+    size_t n_image; /**< Number of images. */
+    size_t image_width; /**< Width of the image. */
+    size_t image_height; /**< Height of the image. */
+} RKLLMMultiModalInput;
+
+/**
+ * @struct RKLLMInput
+ * @brief Represents different types of input to the LLM via a union.
+ */
+typedef struct {
+    const char* role; /**< Message role: "user" (user input), "tool" (function result) */
+    bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */
+    RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */
+    union {
+        const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */
+        RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */
+        RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */
+        RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */
+    };
+} RKLLMInput;
+
+/**
+ * @struct RKLLMLoraParam
+ * @brief Structure defining parameters for Lora adapters.
+ */
+typedef struct {
+    const char* lora_adapter_name; /**< Name of the Lora adapter. */
+} RKLLMLoraParam;
+
+/**
+ * @struct RKLLMPromptCacheParam
+ * @brief Structure to define parameters for caching prompts.
+ */
+typedef struct {
+    int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */
+    const char* prompt_cache_path; /**< Path to the prompt cache file. */
+} RKLLMPromptCacheParam;
+
+/**
+ * @struct RKLLMCrossAttnParam
+ * @brief Structure holding parameters for cross-attention inference.
+ *
+ * This structure is used when performing cross-attention in the decoder.
+ * It provides the encoder output (key/value caches), position indices,
+ * and attention mask.
+ *
+ * - `encoder_k_cache` must be stored in contiguous memory with layout:
+ *   [num_layers][num_tokens][num_kv_heads][head_dim]
+ * - `encoder_v_cache` must be stored in contiguous memory with layout:
+ *   [num_layers][num_kv_heads][head_dim][num_tokens]
+ */
+typedef struct {
+    float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). */
+    float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */
+    float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */
+    int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */
+    int num_tokens; /**< Number of tokens in the encoder sequence. */
+} RKLLMCrossAttnParam;
+
+/**
+ * @struct RKLLMInferParam
+ * @brief Structure for defining parameters during inference.
+ */
+typedef struct {
+    RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */
+    RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */
+    RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */
+    int keep_history; /**< Flag to determine history retention (1: keep history, 0: discard history). */
+} RKLLMInferParam;
+
+/**
+ * @struct RKLLMResultLastHiddenLayer
+ * @brief Structure to hold the hidden states from the last layer.
+ */
+typedef struct {
+    const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */
+    int embd_size; /**< Size of the embedding vector. */
+    int num_tokens; /**< Number of tokens for which hidden states are stored. */
+} RKLLMResultLastHiddenLayer;
+
+/**
+ * @struct RKLLMResultLogits
+ * @brief Structure to hold the logits.
+ */
+typedef struct {
+    const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */
+    int vocab_size; /**< Size of the vocab. */
+    int num_tokens; /**< Number of tokens for which logits are stored. */
+} RKLLMResultLogits;
+
+/**
+ * @struct RKLLMPerfStat
+ * @brief Structure to hold performance statistics for prefill and generate stages.
+ */
+typedef struct {
+    float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */
+    int prefill_tokens; /**< Number of tokens processed during the prefill stage. */
+    float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */
+    int generate_tokens; /**< Number of tokens processed during the generate stage. */
+    float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */
+} RKLLMPerfStat;
+
+/**
+ * @struct RKLLMResult
+ * @brief Structure to represent the result of LLM inference.
+ */
+typedef struct {
+    const char* text; /**< Generated text result. */
+    int32_t token_id; /**< ID of the generated token. */
+    RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */
+    RKLLMResultLogits logits; /**< Model output logits. */
+    RKLLMPerfStat perf; /**< Performance statistics (prefill and generate). */
+} RKLLMResult;
+
+/**
+ * @typedef LLMResultCallback
+ * @brief Callback function to handle LLM results.
+ * @param result Pointer to the LLM result.
+ * @param userdata Pointer to user data for the callback.
+ * @param state State of the LLM call (e.g., finished, error).
+ * @return int Return value indicating the handling status:
+ *         - 0: Continue inference normally.
+ *         - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt),
+ *              return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference.
+ */
+typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state);
+
+/**
+ * @brief Creates a default RKLLMParam structure with preset values.
+ * @return A default RKLLMParam structure.
+ */
+RKLLMParam rkllm_createDefaultParam();
+
+/**
+ * @brief Initializes the LLM with the given parameters.
+ * @param handle Pointer to the LLM handle.
+ * @param param Configuration parameters for the LLM.
+ * @param callback Callback function to handle LLM results.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
+
+/**
+ * @brief Loads a Lora adapter into the LLM.
+ * @param handle LLM handle.
+ * @param lora_adapter Pointer to the Lora adapter structure.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
+
+/**
+ * @brief Loads a prompt cache from a file.
+ * @param handle LLM handle.
+ * @param prompt_cache_path Path to the prompt cache file.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
+
+/**
+ * @brief Releases the prompt cache from memory.
+ * @param handle LLM handle.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_release_prompt_cache(LLMHandle handle);
+
+/**
+ * @brief Destroys the LLM instance and releases resources.
+ * @param handle LLM handle.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_destroy(LLMHandle handle);
+
+/**
+ * @brief Runs an LLM inference task synchronously.
+ * @param handle LLM handle.
+ * @param rkllm_input Input data for the LLM.
+ * @param rkllm_infer_params Parameters for the inference task.
+ * @param userdata Pointer to user data for the callback.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
+
+/**
+ * @brief Runs an LLM inference task asynchronously.
+ * @param handle LLM handle.
+ * @param rkllm_input Input data for the LLM.
+ * @param rkllm_infer_params Parameters for the inference task.
+ * @param userdata Pointer to user data for the callback.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
+
+/**
+ * @brief Aborts an ongoing LLM task.
+ * @param handle LLM handle.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_abort(LLMHandle handle);
+
+/**
+ * @brief Checks if an LLM task is currently running.
+ * @param handle LLM handle.
+ * @return Status code (0 if a task is running, non-zero otherwise).
+ */
+int rkllm_is_running(LLMHandle handle);
+
+/**
+ * @brief Clear the key-value cache for a given LLM handle.
+ *
+ * This function is used to clear part or all of the KV cache.
+ *
+ * @param handle LLM handle.
+ * @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear).
+ *                           This flag is ignored if a specific range [start_pos, end_pos) is provided.
+ * @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch.
+ * @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch.
+ *                If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect.
+ *                If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored.
+ * @note start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback.
+ * @return Status code (0 if cache was cleared successfully, non-zero otherwise).
+ */
+int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
+
+/**
+ * @brief Get the current size of the key-value cache for a given LLM handle.
+ *
+ * This function returns the total number of positions currently stored in the model's KV cache.
+ *
+ * @param handle LLM handle.
+ * @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored.
+ *                    The array must be preallocated with space for `n_batch` elements.
+ * @return Status code (0 for success, non-zero for failure).
+ */
+int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
+
+/**
+ * @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix.
+ *
+ * This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix.
+ * The system prompt is typically used to define the behavior or context of the language model,
+ * while the prefix and postfix are used to format the user input and output respectively.
+ *
+ * @param handle LLM handle.
+ * @param system_prompt The system prompt that defines the context or behavior of the language model.
+ * @param prompt_prefix The prefix added before the user input in the chat.
+ * @param prompt_postfix The postfix added after the user input in the chat.
+ *
+ * @return Status code (0 if the template was set successfully, non-zero for errors).
+ */
+int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
+
+/**
+ * @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token.
+ *
+ * @param handle LLM handle.
+ * @param system_prompt The system prompt that defines the context or behavior of the language model.
+ * @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters.
+ * @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag,
+ *                          allowing the tokenizer to recognize tool outputs separately from normal dialogue turns.
+ * @return Status code (0 if the configuration was set successfully, non-zero for errors).
+ */
+int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
+
+/**
+ * @brief Sets the cross-attention parameters for the LLM decoder.
+ *
+ * @param handle LLM handle.
+ * @param cross_attn_params Pointer to the structure containing encoder-related input data
+ *                          used for cross-attention (see RKLLMCrossAttnParam for details).
+ *
+ * @return Status code (0 if the parameters were set successfully, non-zero for errors).
+ */
+int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
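`inference.py` calls a `cpu_mask_for_platform` helper (defined earlier in the file, outside this excerpt) to fill `RKLLMExtendParam.enabled_cpus_mask` from the `CPU0`..`CPU7` bit macros in `rkllm.h`. A minimal sketch of that idea, assuming the real helper may differ and that on rk3588/rk3576 the big cores sit at CPU4-CPU7 and should be preferred:

```python
# Hypothetical sketch of cpu_mask_for_platform; the helper shipped in
# inference.py may be implemented differently. Each bit matches the
# CPU<n> = (1 << n) macros defined in rkllm.h.

def cpu_mask_for_platform(target_platform: str, enabled_cpus_num: int) -> int:
    if target_platform in ("rk3588", "rk3576"):
        # 8 cores; assume big cores are CPU4-CPU7, so fill from the top down.
        order = [7, 6, 5, 4, 3, 2, 1, 0]
    else:
        # rk3562/rv1126b: 4 cores, fill from CPU0 upward.
        order = [0, 1, 2, 3]
    mask = 0
    for cpu in order[:enabled_cpus_num]:
        mask |= 1 << cpu
    return mask
```

With this sketch, `--enabled-cpus-num 4` on rk3588 would select CPU4-CPU7 (mask `0xF0`), i.e. the four Cortex-A76 cores, which is the usual choice for NPU-adjacent LLM decoding.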