diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 3dc6f6941..0741d130c 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -754,6 +754,7 @@ jobs:
   cpp_ext_test:
     name: C++ Extension Test (${{ matrix.os }})
     needs: [ clone ]
+    if: always()
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
diff --git a/.github/workflows/publish_release.yml b/.github/workflows/publish_release.yml
index 5d23326c4..ae7ecfb86 100644
--- a/.github/workflows/publish_release.yml
+++ b/.github/workflows/publish_release.yml
@@ -191,12 +191,6 @@ jobs:
           name: repo-with-docs
           path: ./repo_artifact
 
-      - name: Install Python dev headers (Ubuntu only)
-        if: startsWith(matrix.os, 'ubuntu')
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y python3-dev
-
       - name: Extract repo-with-docs
         run: |
           set -ex
diff --git a/.gitignore b/.gitignore
index 0e34cdd67..c9b8793e2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ test/
 dist/
 tmp/
 build
+.cache/
 *.lock
 *.db
 mkdocs.yml
@@ -64,3 +65,4 @@ docs/zh/assets
 build*
 lazyllm_cpp.egg-info/
 !build*.sh
+lazyllm/cpp_lib/
\ No newline at end of file
diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 055fd5cbf..4bb12f105 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -1,23 +1,77 @@
 cmake_minimum_required(VERSION 3.16)
 project(LazyLLMCPP LANGUAGES CXX)
 
-set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 
-find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED)
-find_package(pybind11 CONFIG REQUIRED)
+# Third party libs
+include(cmake/third_party.cmake)
 
 # Config lazyllm_core lib with pure cpp code.
-file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
+file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS
+    "${CMAKE_CURRENT_SOURCE_DIR}/core/src/*.cpp")
 add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES})
-target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
+target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/include)
+target_link_libraries(lazyllm_core PUBLIC xxhash)
+target_link_libraries(lazyllm_core PUBLIC tiktoken)
+target_link_libraries(lazyllm_core PUBLIC utf8proc)
+target_compile_options(lazyllm_core PRIVATE -Werror -Wshadow)
+
+# Config lazyllm_adaptor lib which maintains callback invocations.
+file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS
+    "${CMAKE_CURRENT_SOURCE_DIR}/adaptor/*.cpp")
+add_library(lazyllm_adaptor STATIC ${LAZYLLM_ADAPTOR_SOURCES})
+target_include_directories(lazyllm_adaptor PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/adaptor)
+target_link_libraries(lazyllm_adaptor PUBLIC pybind11::headers Python3::Python lazyllm_core)
+target_compile_options(lazyllm_adaptor PRIVATE -Werror -Wshadow)
 
 # Config lazyllm_cpp lib with binding infomations.
-set(LAZYLLM_BINDING_SOURCES binding/lazyllm.cpp binding/doc.cpp)
+file(GLOB_RECURSE LAZYLLM_BINDING_SOURCES CONFIGURE_DEPENDS
+    "${CMAKE_CURRENT_SOURCE_DIR}/binding/*.cpp")
 set(INTERFACE_TARGET_NAME lazyllm_cpp)
 pybind11_add_module(${INTERFACE_TARGET_NAME} ${LAZYLLM_BINDING_SOURCES})
-target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core)
+target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core lazyllm_adaptor)
+target_compile_options(${INTERFACE_TARGET_NAME} PRIVATE -Werror -Wshadow)
+
+# Runtime loader configuration per platform.
+set(_lazyllm_cpp_rpath "")
+set(LAZYLLM_TEST_RUNTIME_ENV "" CACHE INTERNAL "Runtime env for LazyLLM C++ tests" FORCE)
+if (WIN32)
+    # Windows has no ELF rpath; loader resolution is driven by PATH and DLL search order.
+    # Keep test runtime env empty by default.
+elseif (APPLE)
+    # Ensure lazyllm_cpp can find third-party dylibs under lazyllm/cpp_lib.
+    list(APPEND _lazyllm_cpp_rpath "@loader_path/cpp_lib")
+else ()
+    # Ensure lazyllm_cpp can find third-party shared libraries under lazyllm/cpp_lib.
+    list(APPEND _lazyllm_cpp_rpath "$ORIGIN/cpp_lib")
+    # Use DT_RPATH (instead of DT_RUNPATH) so the extension's own runtime
+    # search path can take precedence over host interpreter bundled libs.
+    target_link_options(${INTERFACE_TARGET_NAME} PRIVATE -Wl,--disable-new-dtags)
+
+    # Resolve libstdc++ from the active C++ compiler and include it in rpath.
+    execute_process(
+        COMMAND ${CMAKE_CXX_COMPILER} -print-file-name=libstdc++.so.6
+        OUTPUT_VARIABLE LIBSTDCPP_PATH
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    if (LIBSTDCPP_PATH AND NOT LIBSTDCPP_PATH STREQUAL "libstdc++.so.6")
+        get_filename_component(LIBSTDCPP_DIR "${LIBSTDCPP_PATH}" DIRECTORY)
+        if (LIBSTDCPP_DIR)
+            list(APPEND _lazyllm_cpp_rpath "${LIBSTDCPP_DIR}")
+            set(LAZYLLM_TEST_RUNTIME_ENV "LD_LIBRARY_PATH=${LIBSTDCPP_DIR}:$ENV{LD_LIBRARY_PATH}"
+                CACHE INTERNAL "Runtime env for LazyLLM C++ tests" FORCE)
+        endif ()
+    endif ()
+endif ()
+
+if (_lazyllm_cpp_rpath)
+    set_target_properties(${INTERFACE_TARGET_NAME} PROPERTIES
+        BUILD_RPATH "${_lazyllm_cpp_rpath}"
+        INSTALL_RPATH "${_lazyllm_cpp_rpath}"
+    )
+endif ()
 
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
     # SHOW_SYMBOL
@@ -26,7 +80,14 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug")
 endif()
 
 # Install
-install(TARGETS ${INTERFACE_TARGET_NAME} LIBRARY DESTINATION lazyllm)
+install(TARGETS ${INTERFACE_TARGET_NAME}
+    LIBRARY DESTINATION lazyllm COMPONENT lazyllm_cpp
+    RUNTIME DESTINATION lazyllm COMPONENT lazyllm_cpp
+)
+install(TARGETS tiktoken utf8proc
+    LIBRARY DESTINATION lazyllm/cpp_lib COMPONENT lazyllm_cpp
+    RUNTIME DESTINATION lazyllm/cpp_lib COMPONENT lazyllm_cpp
+)
 
 # TESTS
diff --git a/csrc/include/README.md b/csrc/README.md
similarity index 100%
rename from csrc/include/README.md
rename to csrc/README.md
diff --git a/csrc/adaptor/adaptor.cpp b/csrc/adaptor/adaptor.cpp
new file mode 100644
index 000000000..cd57a3bbe
--- /dev/null
+++ b/csrc/adaptor/adaptor.cpp
@@ -0,0 +1,2 @@
+#include "adaptor_base_wrapper.hpp"
+#include "document_store.hpp"
diff --git a/csrc/adaptor/adaptor_base_wrapper.hpp b/csrc/adaptor/adaptor_base_wrapper.hpp
new file mode 100644
index 000000000..bafd78a60
--- /dev/null
+++ b/csrc/adaptor/adaptor_base_wrapper.hpp
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <any>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include <pybind11/pybind11.h>
+
+#include "adaptor_base.hpp"
+
+
+namespace lazyllm {
+
+class LAZYLLM_HIDDEN AdaptorBaseWrapper : public AdaptorBase {
+    pybind11::object _py_obj;
+public:
+    AdaptorBaseWrapper(const pybind11::object &obj) : _py_obj(obj) {}
+    virtual ~AdaptorBaseWrapper() = default;
+
+    std::any call(
+        const std::string& func_name,
+        const std::unordered_map<std::string, std::any>& args) const override final
+    {
+        pybind11::gil_scoped_acquire gil;
+        pybind11::object func = pybind11::getattr(_py_obj, func_name.c_str(), pybind11::none());
+        return call_impl(func_name, func, args);
+    }
+
+    virtual std::any call_impl(
+        const std::string& func_name,
+        const pybind11::object& func,
+        const std::unordered_map<std::string, std::any>& args) const = 0;
+};
+
+}
diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp
new file mode 100644
index 000000000..1be72d066
--- /dev/null
+++ b/csrc/adaptor/document_store.hpp
@@ -0,0 +1,119 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "adaptor_base_wrapper.hpp"
+#include "doc_node.hpp"
+
+namespace lazyllm {
+
+struct NodeGroup {
+    enum class Type {
+        ORIGINAL, CHUNK, SUMMARY, IMAGE_INFO, QUESTION_ANSWER, OTHER
+    };
+    std::string _parent;
+    std::string _display_name;
+    Type _type;
+    NodeGroup(
+        const std::string& parent,
+        const std::string& display_name,
+        const Type& type = Type::ORIGINAL) :
+        _parent(parent), _display_name(display_name), _type(type) {}
+};
+
+class LAZYLLM_HIDDEN DocumentStore : public AdaptorBaseWrapper {
+public:
+    DocumentStore() = delete;
+    explicit DocumentStore(
+        const pybind11::object& store,
+        const std::unordered_map<std::string, NodeGroup> &map) :
+        AdaptorBaseWrapper(store), _node_groups_map(map) {}
+
+    // Cache-aware factory to avoid rebuilding adaptor for the same Python store.
+    static std::shared_ptr<DocumentStore> from_store(
+        const pybind11::object& store, const std::unordered_map<std::string, NodeGroup>& map) {
+        if (store.is_none()) return nullptr;
+
+        pybind11::gil_scoped_acquire gil;
+        PyObject *key = store.ptr();
+        auto &cache = store_cache();
+        auto it = cache.find(key);
+        if (it != cache.end()) {
+            if (auto existing = it->second.lock())
+                return existing;
+        }
+        auto created = std::make_shared<DocumentStore>(store, map);
+        cache[key] = created;
+        return created;
+    }
+
+    DocNode::Children get_node_children(const DocNode* node) const {
+        DocNode::Children out;
+        auto& kb_id = std::get<std::string>(node->_p_global_metadata->at(std::string(RAGMetadataKeys::KB_ID)));
+        auto& doc_id = std::get<std::string>(node->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_ID)));
+        auto& group_name = node->_group_name;
+        for(auto& [current_group_name, group] : _node_groups_map) {
+            if (group._parent != group_name) continue;
+            if (!std::any_cast<bool>(call("is_group_active", {{"group", current_group_name}}))) continue;
+            auto nodes_in_group = std::any_cast<std::vector<std::shared_ptr<DocNode>>>(call("get_nodes", {
+                {"group_name", current_group_name},
+                {"kb_id", kb_id},
+                {"doc_ids", std::vector<std::string>({doc_id})}
+            }));
+
+            std::vector<std::shared_ptr<DocNode>> children;
+            children.reserve(nodes_in_group.size());
+            for (auto n : nodes_in_group)
+                if (n->get_parent_node() == node) children.push_back(n);
+            out[current_group_name] = children;
+        }
+        return out;
+    }
+
+private:
+    std::unordered_map<std::string, NodeGroup> _node_groups_map;
+
+    std::any call_impl(
+        const std::string& func_name,
+        const pybind11::object& func,
+        const std::unordered_map<std::string, std::any>& args) const override
+    {
+        if (func_name == "is_group_active") {
+            return func(args.at("group")).cast<bool>();
+        }
+        else if (func_name == "get_node") {
+            return func(
+                pybind11::arg("group_name") = std::any_cast<std::string>(args.at("group_name")),
+                pybind11::arg("uids") = std::vector<std::string>({std::any_cast<std::string>(args.at("uid"))}),
+                pybind11::arg("kb_id") = std::any_cast<std::string>(args.at("kb_id")),
+                pybind11::arg("display") = true
+            ).cast<pybind11::list>()[0].cast<DocNode*>();
+        }
+        else if (func_name == "get_nodes") {
+            return func(
+                pybind11::arg("group_name") = std::any_cast<std::string>(args.at("group_name")),
+                pybind11::arg("kb_id") = std::any_cast<std::string>(args.at("kb_id")),
+                pybind11::arg("doc_ids") = std::vector<std::string>({std::any_cast<std::string>(args.at("doc_id"))})
+            ).cast<std::vector<std::shared_ptr<DocNode>>>();
+        }
+        else if (func_name == "get_node_children") {
+            return get_node_children(std::any_cast<const DocNode*>(args.at("node")));
+        }
+
+        throw std::runtime_error("Unknown DocumentStore function: " + func_name);
+    }
+
+    // Cache by Python object identity to ensure one wrapper per store instance.
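+    // The map holds weak_ptrs (note the .lock() above), so the cache never
+    // extends a wrapper's lifetime: once the last shared_ptr handed out by
+    // from_store() is dropped, the entry expires and the next from_store()
+    // call rebuilds the adaptor. A hypothetical caller on the binding side:
+    //
+    //   auto adaptor = DocumentStore::from_store(py_store, groups);
+    //   auto children = adaptor->get_node_children(node);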
+ static std::unordered_map> &store_cache() { + static std::unordered_map> cache; + return cache; + } +}; + +} // namespace lazyllm diff --git a/csrc/binding/binding_utils.cpp b/csrc/binding/binding_utils.cpp new file mode 100644 index 000000000..cf5d82d67 --- /dev/null +++ b/csrc/binding/binding_utils.cpp @@ -0,0 +1,156 @@ +#include "binding_utils.hpp" + +#include + +namespace lazyllm::pybind_utils { + +std::string DumpJson(const py::object& obj) { + py::object json = py::module_::import("json"); + py::object dumps = json.attr("dumps"); + py::object dumped = dumps(obj, py::arg("ensure_ascii") = false); + return dumped.cast(); +} + +py::object LoadJson(const std::string& text) { + py::object json = py::module_::import("json"); + py::object loads = json.attr("loads"); + return loads(py::str(text)); +} + +bool ExtractStringSequence(const py::object& obj, std::vector* out) { + if (!py::isinstance(obj) || py::isinstance(obj)) return false; + py::sequence seq = obj.cast(); + out->clear(); + out->reserve(seq.size()); + for (py::handle item : seq) { + if (!py::isinstance(item)) { + out->clear(); + return false; + } + out->push_back(py::cast(item)); + } + return true; +} + +lazyllm::MetadataMode ParseMetadataMode(const py::object& mode) { + if (mode.is_none()) return lazyllm::MetadataMode::NONE; + try { + if (py::hasattr(mode, "name")) { + const auto name = py::cast(mode.attr("name")); + if (name == "ALL") return lazyllm::MetadataMode::ALL; + if (name == "EMBED") return lazyllm::MetadataMode::EMBED; + if (name == "LLM") return lazyllm::MetadataMode::LLM; + if (name == "NONE") return lazyllm::MetadataMode::NONE; + } + } catch (const py::error_already_set&) { + } + if (py::isinstance(mode)) { + const auto name = mode.cast(); + if (name == "ALL") return lazyllm::MetadataMode::ALL; + if (name == "EMBED") return lazyllm::MetadataMode::EMBED; + if (name == "LLM") return lazyllm::MetadataMode::LLM; + if (name == "NONE") return lazyllm::MetadataMode::NONE; + } + if (py::isinstance(mode)) { + const auto value = mode.cast(); + switch (value) { + case 0: return lazyllm::MetadataMode::ALL; + case 1: return lazyllm::MetadataMode::EMBED; + case 2: return lazyllm::MetadataMode::LLM; + case 3: return lazyllm::MetadataMode::NONE; + default: break; + } + } + return lazyllm::MetadataMode::NONE; +} + +lazyllm::DocNode::MetadataVType PyToMetadataValue(const py::handle& value) { + if (value.is_none()) return std::any(py::none()); + if (py::isinstance(value)) return static_cast(value.cast()); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + + if (py::isinstance(value) && !py::isinstance(value)) { + py::sequence seq = value.cast(); + if (seq.empty()) return std::any(py::reinterpret_borrow(value)); + + bool all_str = true; + bool all_int = true; + bool all_numeric = true; + + for (py::handle item : seq) { + const bool is_str = py::isinstance(item); + const bool is_int = py::isinstance(item) && !py::isinstance(item); + const bool is_numeric = is_int || py::isinstance(item) || py::isinstance(item); + all_str = all_str && is_str; + all_int = all_int && is_int; + all_numeric = all_numeric && is_numeric; + } + + if (all_str) { + std::vector out; + out.reserve(seq.size()); + for (py::handle item : seq) out.push_back(py::cast(item)); + return out; + } + if (all_int) { + std::vector out; + out.reserve(seq.size()); + for (py::handle item : seq) out.push_back(py::cast(item)); + return out; + } + if (all_numeric) { + std::vector 
out; + out.reserve(seq.size()); + for (py::handle item : seq) out.push_back(py::cast(item)); + return out; + } + } + return std::any(py::reinterpret_borrow(value)); +} + +py::object MetadataValueToPy(const lazyllm::DocNode::MetadataVType& value) { + return std::visit([](const auto& v) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v) return py::str(v); + if constexpr (std::is_same_v) return py::int_(v); + if constexpr (std::is_same_v) return py::float_(v); + if constexpr (std::is_same_v>) return py::cast(v); + if constexpr (std::is_same_v>) return py::cast(v); + if constexpr (std::is_same_v>) return py::cast(v); + if constexpr (std::is_same_v) return AnyToPy(v); + return py::none(); + }, value); +} + +std::any PyToAny(const py::handle& value) { + if (value.is_none()) return py::none(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + return py::reinterpret_borrow(value); +} + +py::object AnyToPy(const std::any& value) { + const auto& t = value.type(); + if (!value.has_value()) return py::none(); + if (t == typeid(py::object)) return std::any_cast(value); + if (t == typeid(py::none)) return py::none(); + if (t == typeid(std::string)) return py::str(std::any_cast(value)); + if (t == typeid(const char*)) return py::str(std::any_cast(value)); + if (t == typeid(char*)) return py::str(std::any_cast(value)); + if (t == typeid(bool)) return py::bool_(std::any_cast(value)); + if (t == typeid(int)) return py::int_(std::any_cast(value)); + if (t == typeid(long)) return py::int_(std::any_cast(value)); + if (t == typeid(long long)) return py::int_(std::any_cast(value)); + if (t == typeid(unsigned int)) return py::int_(std::any_cast(value)); + if (t == typeid(unsigned long)) return py::int_(std::any_cast(value)); + if (t == typeid(unsigned long long)) return py::int_(std::any_cast(value)); + if (t == typeid(float)) return py::float_(std::any_cast(value)); + if (t == typeid(double)) return py::float_(std::any_cast(value)); + return py::str(""); +} + +} // namespace lazyllm::pybind_utils diff --git a/csrc/binding/binding_utils.hpp b/csrc/binding/binding_utils.hpp new file mode 100644 index 000000000..c1d8252ae --- /dev/null +++ b/csrc/binding/binding_utils.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include +#include + +#include "lazyllm.hpp" +#include "doc_node.hpp" + +namespace lazyllm::pybind_utils { + +std::string DumpJson(const py::object& obj); +py::object LoadJson(const std::string& text); +bool ExtractStringSequence(const py::object& obj, std::vector* out); +lazyllm::MetadataMode ParseMetadataMode(const py::object& mode); +lazyllm::DocNode::MetadataVType PyToMetadataValue(const py::handle& value); +py::object MetadataValueToPy(const lazyllm::DocNode::MetadataVType& value); +std::any PyToAny(const py::handle& value); +py::object AnyToPy(const std::any& value); + +} // namespace lazyllm::pybind_utils + +namespace pybind11::detail { + +template <> +struct type_caster { +public: + PYBIND11_TYPE_CASTER(lazyllm::MetadataVType, _("MetadataVType")); + + bool load(handle src, bool) { + try { + value = lazyllm::pybind_utils::PyToMetadataValue(src); + return true; + } catch (const pybind11::error_already_set&) { + PyErr_Clear(); + return false; + } catch (...) 
{ + return false; + } + } + + static handle cast(const lazyllm::MetadataVType& src, return_value_policy, handle) { + pybind11::object obj = lazyllm::pybind_utils::MetadataValueToPy(src); + return obj.release(); + } +}; + +} // namespace pybind11::detail diff --git a/csrc/binding/doc.cpp b/csrc/binding/export_add_doc_str.cpp similarity index 97% rename from csrc/binding/doc.cpp rename to csrc/binding/export_add_doc_str.cpp index 587315a2a..4ef0fd53b 100644 --- a/csrc/binding/doc.cpp +++ b/csrc/binding/export_add_doc_str.cpp @@ -28,6 +28,6 @@ void addDocStr(py::object obj, std::string docs) { } } -void exportDoc(py::module& m) { +void exportAddDocStr(py::module& m) { m.def("add_doc", &addDocStr, "Add docstring to a function or method", py::arg("obj"), py::arg("docs")); } diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp new file mode 100644 index 000000000..f43abb427 --- /dev/null +++ b/csrc/binding/export_doc_node.cpp @@ -0,0 +1,694 @@ +#include +#include +#include +#include +#include +#include + +#include "lazyllm.hpp" +#include "document_store.hpp" +#include "doc_node.hpp" +#include "binding_utils.hpp" +#include + +PYBIND11_MAKE_OPAQUE(lazyllm::DocNode::Metadata); +PYBIND11_MAKE_OPAQUE(lazyllm::DocNode::Children); + +namespace { + +namespace pyu = lazyllm::pybind_utils; + +const std::unordered_map kDocNodeAttrAliases = { + {"uid", "_uid"}, + {"group", "_group"}, + {"content", "_content"}, + {"parent", "_parent"}, + {"children", "_children"}, + {"global_metadata", "_global_metadata"}, + {"metadata", "_metadata"}, + {"excluded_embed_metadata_keys", "_excluded_embed_metadata_keys"}, + {"excluded_llm_metadata_keys", "_excluded_llm_metadata_keys"}, +}; + +const std::unordered_set kDocNodeReadonlyAliases = { + "uid", + "group", +}; + +const char* ResolveDocNodeAlias(const std::string& attr_name) { + const auto it = kDocNodeAttrAliases.find(attr_name); + if (it == kDocNodeAttrAliases.end()) return nullptr; + return it->second; +} + +bool IsJsonDocNode(const py::object& self) { + try { + const auto name = py::cast(py::type::of(self).attr("__name__")); + return name == "JsonDocNode"; + } catch (const py::error_already_set&) { + return false; + } +} + +lazyllm::DocNode::Metadata MetadataFromPy(const py::object& obj) { + lazyllm::DocNode::Metadata out; + if (obj.is_none()) return out; + py::dict d = py::dict(obj); + out.reserve(d.size()); + for (auto item : d) { + const std::string key = py::cast(item.first); + out.emplace(key, pyu::PyToMetadataValue(item.second)); + } + return out; +} + +py::dict MetadataToPy(const lazyllm::DocNode::Metadata& meta) { + py::dict d; + for (const auto& [key, value] : meta) { + d[py::str(key)] = pyu::MetadataValueToPy(value); + } + return d; +} + +lazyllm::DocNode::Metadata& EnsureRootGlobalMetadata(lazyllm::DocNode& node) { + auto* root = const_cast(node.get_root_node()); + if (!root->_p_global_metadata) { + root->_p_global_metadata = std::make_shared(); + } + return *(root->_p_global_metadata); +} + +const lazyllm::DocNode::Metadata& GetRootGlobalMetadata(const lazyllm::DocNode& node) { + static const lazyllm::DocNode::Metadata empty; + auto* root = node.get_root_node(); + if (!root || !root->_p_global_metadata) return empty; + return *(root->_p_global_metadata); +} + +lazyllm::DocNode::Children ChildrenFromPy(const py::object& obj) { + lazyllm::DocNode::Children out; + if (obj.is_none()) return out; + py::dict d = py::dict(obj); + out.reserve(d.size()); + for (auto item : d) { + const std::string group = py::cast(item.first); + 
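+        // Each mapped value is expected to be an iterable of DocNode objects;
+        // None entries are skipped below rather than raising.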
py::object nodes_obj = py::reinterpret_borrow(item.second); + std::vector nodes; + for (auto child_obj : nodes_obj) { + py::object child = py::reinterpret_borrow(child_obj); + if (child.is_none()) continue; + nodes.emplace_back(child.cast()); + } + out.emplace(group, std::move(nodes)); + } + return out; +} + +py::dict ChildrenToPy(const lazyllm::DocNode::Children& children) { + py::dict d; + for (const auto& [group, nodes] : children) { + py::list py_nodes; + for (const auto& n : nodes) py_nodes.append(n); + d[py::str(group)] = std::move(py_nodes); + } + return d; +} + +std::vector NodeVectorFromPy(const py::object& obj) { + std::vector out; + if (obj.is_none()) return out; + for (auto item : obj) { + py::object node = py::reinterpret_borrow(item); + if (node.is_none()) continue; + out.emplace_back(node.cast()); + } + return out; +} + +using NodeGroups = std::unordered_map>; + +std::optional NodeGroupsFromPy(const py::object& obj) { + if (obj.is_none()) return std::nullopt; + py::dict d = py::dict(obj); + NodeGroups out; + out.reserve(d.size()); + for (auto item : d) { + const std::string key = py::cast(item.first); + py::object group_obj = py::reinterpret_borrow(item.second); + py::dict group_dict = py::dict(group_obj); + std::unordered_map inner; + inner.reserve(group_dict.size()); + for (auto kv : group_dict) { + const std::string inner_key = py::cast(kv.first); + inner.emplace(inner_key, pyu::PyToAny(kv.second)); + } + out.emplace(key, std::move(inner)); + } + return out; +} + +std::optional>> NormalizeContent( + const py::object& content +) { + if (content.is_none()) return std::nullopt; + if (py::isinstance(content)) return content.cast(); + std::vector texts; + if (pyu::ExtractStringSequence(content, &texts)) return texts; + return pyu::DumpJson(content); +} + +lazyllm::DocNode init( + std::optional uid, + py::object content, + std::optional group, + std::optional embedding, + py::object parent, + py::object store, + py::object node_groups, + py::object metadata, + py::object global_metadata, + py::object text +) { + if (!content.is_none() && !text.is_none()) + throw std::invalid_argument("`text` and `content` cannot be set at the same time."); + + lazyllm::DocNode* p_parent_node = nullptr; + std::shared_ptr store_adaptor = nullptr; + + // Build node groups map. + // Usually, parent + store + node_groups are not None at the same time. 
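+    // The store-backed path below is only taken when parent, store,
+    // node_groups, group and global_metadata are all provided; otherwise a
+    // non-string `parent` (if any) is used directly as a DocNode pointer.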
+ const auto node_groups_opt = NodeGroupsFromPy(node_groups); + const auto metadata_map = MetadataFromPy(metadata); + const auto global_metadata_map = MetadataFromPy(global_metadata); + const bool has_parent = !parent.is_none(); + const bool has_store_context = has_parent && !store.is_none() + && node_groups_opt.has_value() && !global_metadata.is_none() && group.has_value(); + if (has_store_context) { + std::unordered_map node_groups_map; + node_groups_map.reserve(node_groups_opt->size()); + for (const auto& [group_key, group_dict] : *node_groups_opt) { + node_groups_map.emplace(group_key, lazyllm::NodeGroup( + std::any_cast(group_dict.at(std::string("parent"))), + std::any_cast(group_dict.at(std::string("display_name"))) + )); + } + store_adaptor = lazyllm::DocumentStore::from_store(store, node_groups_map); + + auto kb_id = std::get(global_metadata_map.at( + std::string(lazyllm::RAGMetadataKeys::KB_ID))); + auto doc_id = std::get(global_metadata_map.at( + std::string(lazyllm::RAGMetadataKeys::DOC_ID))); + + if (py::isinstance(parent)) { + const auto parent_uid = parent.cast(); + p_parent_node = std::any_cast(store_adaptor->call("get_node", + {{"group_name", *group}, {"uid", parent_uid}, {"kb_id", kb_id}})); + } else { + p_parent_node = parent.cast(); + } + } else if (has_parent && !py::isinstance(parent)) { + p_parent_node = parent.cast(); + } + + std::string raw_text; + + lazyllm::DocNode node( + "", + group.value_or(""), + uid.value_or(""), + p_parent_node, + metadata_map, + std::make_shared(global_metadata_map) + ); + if (store_adaptor) node.set_store(store_adaptor); + if (embedding) { + for (const auto& [key, vec] : *embedding) + node.set_embedding_vec(key, vec); + } + if (!content.is_none()) { + const auto normalized = NormalizeContent(content); + if (normalized) { + if (const auto* s = std::get_if(&*normalized)) + node.set_root_text(std::move(*s)); + else + node.set_root_texts(std::get>(*normalized)); + } + } + else if (!text.is_none()){ + node.set_root_text(text.cast()); + } + + return node; +} + +std::string DocNodeToString(const lazyllm::DocNode& node) { + py::dict d; + const auto children = node.get_children(); + for (const auto& [group, nodes] : children) { + py::list ids; + for (std::shared_ptr n : nodes) { + if (n) ids.append(n->get_uid()); + } + d[py::str(group)] = std::move(ids); + } + const std::string children_str = py::str(d).cast(); + return "DocNode(id: " + node.get_uid() + ", group: " + node._group_name + + ", content: " + node.get_text(lazyllm::MetadataMode::NONE) + + ") parent: " + node.get_parent_uid() + ", children: " + children_str; +} + +} // namespace + +void exportDocNode(py::module& m) { + py::enum_(m, "MetadataMode") + .value("ALL", lazyllm::MetadataMode::ALL) + .value("EMBED", lazyllm::MetadataMode::EMBED) + .value("LLM", lazyllm::MetadataMode::LLM) + .value("NONE", lazyllm::MetadataMode::NONE); + + auto metadata_map = py::bind_map(m, "DocNodeMetadataMap"); + metadata_map + .def("get", + [](lazyllm::DocNode::Metadata& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it == self.end()) return default_value; + return pyu::MetadataValueToPy(it->second); + }, + py::arg("key"), + py::arg("default") = py::none()) + .def("pop", + [](lazyllm::DocNode::Metadata& self, const std::string& key) { + auto it = self.find(key); + if (it == self.end()) throw py::key_error(py::str(key)); + py::object value = pyu::MetadataValueToPy(it->second); + self.erase(it); + return value; + }, + py::arg("key")) + .def("pop", + 
[](lazyllm::DocNode::Metadata& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it == self.end()) return default_value; + py::object value = pyu::MetadataValueToPy(it->second); + self.erase(it); + return value; + }, + py::arg("key"), + py::arg("default")) + .def("setdefault", + [](lazyllm::DocNode::Metadata& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it != self.end()) return pyu::MetadataValueToPy(it->second); + auto [inserted, ok] = self.emplace(key, pyu::PyToMetadataValue(default_value)); + (void)ok; + return pyu::MetadataValueToPy(inserted->second); + }, + py::arg("key"), + py::arg("default") = py::none()) + .def("copy", + [](const lazyllm::DocNode::Metadata& self) { + return MetadataToPy(self); + }) + .def("__copy__", + [](const lazyllm::DocNode::Metadata& self) { + return MetadataToPy(self); + }) + .def("__deepcopy__", + [](const lazyllm::DocNode::Metadata& self, const py::dict& memo) { + py::object copy = py::module_::import("copy"); + return copy.attr("deepcopy")(MetadataToPy(self), memo); + }, + py::arg("memo")) + .def("update", + [](lazyllm::DocNode::Metadata& self, py::object other, py::kwargs kwargs) { + if (!other.is_none()) { + py::dict d = py::dict(other); + for (auto item : d) { + const std::string key = py::cast(item.first); + self[key] = pyu::PyToMetadataValue(item.second); + } + } + for (auto item : kwargs) { + const std::string key = py::cast(item.first); + self[key] = pyu::PyToMetadataValue(item.second); + } + }, + py::arg("other") = py::none()) + .def("__eq__", + [](const lazyllm::DocNode::Metadata& self, py::object other) { + py::dict lhs = MetadataToPy(self); + if (py::isinstance(other) || py::hasattr(other, "items")) + return py::bool_(lhs.equal(py::dict(other))); + return py::bool_(false); + }, + py::is_operator()); + + auto children_map = py::bind_map(m, "DocNodeChildrenMap"); + children_map + .def("get", + [](lazyllm::DocNode::Children& self, const std::string& key, py::object default_value) -> py::object { + auto it = self.find(key); + if (it == self.end()) return default_value; + return py::cast(it->second); + }, + py::arg("key"), + py::arg("default") = py::none()) + .def("pop", + [](lazyllm::DocNode::Children& self, const std::string& key) { + auto it = self.find(key); + if (it == self.end()) throw py::key_error(py::str(key)); + py::list value = py::cast(it->second); + self.erase(it); + return value; + }, + py::arg("key")) + .def("pop", + [](lazyllm::DocNode::Children& self, const std::string& key, py::object default_value) -> py::object { + auto it = self.find(key); + if (it == self.end()) return default_value; + py::list value = py::cast(it->second); + self.erase(it); + return value; + }, + py::arg("key"), + py::arg("default")) + .def("setdefault", + [](lazyllm::DocNode::Children& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it != self.end()) return py::cast(it->second); + auto [inserted, ok] = self.emplace(key, NodeVectorFromPy(default_value)); + (void)ok; + return py::cast(inserted->second); + }, + py::arg("key"), + py::arg("default") = py::list()) + .def("copy", + [](const lazyllm::DocNode::Children& self) { + return ChildrenToPy(self); + }) + .def("__copy__", + [](const lazyllm::DocNode::Children& self) { + return ChildrenToPy(self); + }) + .def("__deepcopy__", + [](const lazyllm::DocNode::Children& self, const py::dict& memo) { + py::object copy = py::module_::import("copy"); + return 
copy.attr("deepcopy")(ChildrenToPy(self), memo); + }, + py::arg("memo")) + .def("update", + [](lazyllm::DocNode::Children& self, py::object other, py::kwargs kwargs) { + if (!other.is_none()) { + py::dict d = py::dict(other); + for (auto item : d) { + const std::string key = py::cast(item.first); + py::object value = py::reinterpret_borrow(item.second); + self[key] = NodeVectorFromPy(value); + } + } + for (auto item : kwargs) { + const std::string key = py::cast(item.first); + py::object value = py::reinterpret_borrow(item.second); + self[key] = NodeVectorFromPy(value); + } + }, + py::arg("other") = py::none()) + .def("__eq__", + [](const lazyllm::DocNode::Children& self, py::object other) { + py::dict lhs = ChildrenToPy(self); + if (py::isinstance(other) || py::hasattr(other, "items")) + return py::bool_(lhs.equal(py::dict(other))); + return py::bool_(false); + }, + py::is_operator()); + + py::class_>(m, "DocNode", py::dynamic_attr()) + .def(py::init(&init), + py::arg("uid") = py::none(), + py::arg("content") = py::none(), + py::arg("group") = py::none(), + py::arg("embedding") = py::none(), + py::arg("parent") = py::none(), + py::arg("store") = py::none(), + py::arg("node_groups") = py::none(), + py::arg("metadata") = py::none(), + py::arg("global_metadata") = py::none(), + py::arg("text") = py::none() + ) + .def_property("_uid", + [](const lazyllm::DocNode& node) { return node.get_uid(); }, + [](lazyllm::DocNode& node, const std::string& value) { node.set_uid(value); } + ) + .def_property("_group", + [](const lazyllm::DocNode& node) { return node._group_name; }, + [](lazyllm::DocNode& node, const std::string& value) { node._group_name = value; } + ) + .def_property("_content", + py::cpp_function([](const py::object& self) { + const auto& node = self.cast(); + const std::string text = std::string(node.get_text(lazyllm::MetadataMode::NONE)); + if (!IsJsonDocNode(self)) return py::cast(text); + try { + return pyu::LoadJson(text); + } catch (const py::error_already_set&) { + return py::cast(text); + } + }), + py::cpp_function([](py::object self, const py::object& content) { + auto& node = self.cast(); + const auto normalized = NormalizeContent(content); + if (!normalized) return; + if (const auto* content_str = std::get_if(&*normalized)) { + node.set_root_text(std::move(*content_str)); + return; + } + node.set_root_texts(std::get>(*normalized)); + }) + ) + .def_property("number", + [](const lazyllm::DocNode& node) { + auto it = node._metadata.find("lazyllm_store_num"); + if (it == node._metadata.end()) return 0; + if (auto* value = std::get_if(&it->second)) return *value; + if (auto* value = std::get_if(&it->second)) return static_cast(*value); + return 0; + }, + [](lazyllm::DocNode& node, int value) { + node._metadata[std::string("lazyllm_store_num")] = value; + } + ) + .def_property_readonly("text", [](const lazyllm::DocNode& node) { return std::string(node.get_text()); }) + .def_property_readonly("content_hash", [](const lazyllm::DocNode& node) { + return lazyllm::to_hex(node.get_text_hash()); + }) + .def_property("embedding", + [](const lazyllm::DocNode& node) { return node._embedding_vecs; }, + [](lazyllm::DocNode& node, const lazyllm::DocNode::EmbeddingVecs& v) { + node._embedding_vecs = v; + } + ) + .def_property("_parent", + [](const lazyllm::DocNode& node) { return node.get_parent_node(); }, + [](lazyllm::DocNode& node, const py::object& parent) { + if (parent.is_none() || py::isinstance(parent)) { + node.set_parent_node(nullptr); + return; + } + 
node.set_parent_node(parent.cast()); + }, + py::return_value_policy::reference + ) + .def_property("_children", + py::cpp_function([](lazyllm::DocNode& node) -> lazyllm::DocNode::Children& { + node.get_children(); + return node.get_children_ref(); + }, py::return_value_policy::reference_internal), + [](lazyllm::DocNode& node, const py::object& children) { + node.set_children(ChildrenFromPy(children)); + } + ) + .def_property_readonly("root_node", + [](const lazyllm::DocNode& node) { return node.get_root_node(); }, + py::return_value_policy::reference + ) + .def_property_readonly("is_root_node", + [](const lazyllm::DocNode& node) { return node.get_parent_node() == nullptr; } + ) + .def_property("_global_metadata", + py::cpp_function([](lazyllm::DocNode& node) -> lazyllm::DocNode::Metadata& { + return EnsureRootGlobalMetadata(node); + }, py::return_value_policy::reference_internal), + [](lazyllm::DocNode& node, const py::object& meta) { + EnsureRootGlobalMetadata(node) = MetadataFromPy(meta); + } + ) + .def_property("_metadata", + py::cpp_function([](lazyllm::DocNode& node) -> lazyllm::DocNode::Metadata& { + return node._metadata; + }, py::return_value_policy::reference_internal), + [](lazyllm::DocNode& node, const py::object& meta) { + node._metadata = MetadataFromPy(meta); + } + ) + .def_property("_excluded_embed_metadata_keys", + [](const lazyllm::DocNode& node) { + const auto keys = node.get_excluded_embed_metadata_keys(); + return std::vector(keys.begin(), keys.end()); + }, + [](lazyllm::DocNode& node, const py::object& keys_obj) { + std::set keys; + for (auto item : keys_obj) { + keys.insert(py::cast(item)); + } + node.set_excluded_embed_metadata_keys(keys); + } + ) + .def_property("_excluded_llm_metadata_keys", + [](const lazyllm::DocNode& node) { + const auto keys = node.get_excluded_llm_metadata_keys(); + return std::vector(keys.begin(), keys.end()); + }, + [](lazyllm::DocNode& node, const py::object& keys_obj) { + std::set keys; + for (auto item : keys_obj) { + keys.insert(py::cast(item)); + } + node.set_excluded_llm_metadata_keys(keys); + } + ) + .def("__getattr__", [](py::object self, const std::string& attr_name) -> py::object { + if (const char* alias = ResolveDocNodeAlias(attr_name)) { + return py::getattr(self, py::str(alias)); + } + throw py::attribute_error("DocNode has no attribute '" + attr_name + "'"); + }) + .def("__setattr__", [](py::object self, const std::string& attr_name, const py::object& value) { + if (kDocNodeReadonlyAliases.find(attr_name) != kDocNodeReadonlyAliases.end()) { + throw py::attribute_error("property '" + attr_name + "' of 'DocNode' object has no setter"); + } + const char* alias = ResolveDocNodeAlias(attr_name); + const char* target = alias ? 
alias : attr_name.c_str(); + py::object object_setattr = py::module_::import("builtins").attr("object").attr("__setattr__"); + object_setattr(self, py::str(target), value); + }) + .def_property("docpath", + [](const lazyllm::DocNode& node) { return node.get_doc_path(); }, + [](lazyllm::DocNode& node, const std::string& path) { node.set_doc_path(path); } + ) + .def_property("embedding_state", + [](const lazyllm::DocNode& node) { return node._pending_embedding_keys; }, + [](lazyllm::DocNode& node, const std::set& keys) { + node._pending_embedding_keys = keys; + } + ) + .def_property("relevance_score", + [](const lazyllm::DocNode& node) { return node._relevance_score; }, + [](lazyllm::DocNode& node, double score) { node._relevance_score = score; } + ) + .def_property("similarity_score", + [](const lazyllm::DocNode& node) { return node._similarity_score; }, + [](lazyllm::DocNode& node, double score) { node._similarity_score = score; } + ) + .def("get_children_str", [](const lazyllm::DocNode& node) { + py::dict d; + const auto children = node.get_children(); + for (const auto& [group, nodes] : children) { + py::list ids; + for (std::shared_ptr n : nodes) if (n) ids.append(n->get_uid()); + d[py::str(group)] = std::move(ids); + } + return py::str(d); + }) + .def("get_parent_id", &lazyllm::DocNode::get_parent_uid) + .def("__str__", &DocNodeToString) + .def("__repr__", [](const lazyllm::DocNode& node) { + py::object cfg = py::module_::import("lazyllm").attr("config"); + py::object mode = py::module_::import("lazyllm").attr("Mode"); + py::object cfg_mode = cfg.attr("__getitem__")("mode"); + if (py::bool_(cfg_mode.equal(mode.attr("Debug")))) + return DocNodeToString(node); + return std::string(""; + }) + .def("__eq__", &lazyllm::DocNode::operator==, py::is_operator()) + .def("__hash__", [](const lazyllm::DocNode& node) { + return static_cast(std::hash{}(node.get_uid())); + }) + .def("__getstate__", [](const lazyllm::DocNode& node) { + py::dict st; + st["_uid"] = node.get_uid(); + st["_content"] = node.get_text(lazyllm::MetadataMode::NONE); + st["_group"] = node._group_name; + st["_embedding"] = node._embedding_vecs; + st["_metadata"] = MetadataToPy(node._metadata); + st["_global_metadata"] = MetadataToPy(GetRootGlobalMetadata(node)); + st["_excluded_embed_metadata_keys"] = node.get_excluded_embed_metadata_keys(); + st["_excluded_llm_metadata_keys"] = node.get_excluded_llm_metadata_keys(); + st["_store"] = py::none(); + st["_node_groups"] = py::none(); + return st; + }) + .def("has_missing_embedding", [](const lazyllm::DocNode& node, const py::object& keys) { + if (py::isinstance(keys)) { + const auto key = keys.cast(); + const auto missing = node.embedding_keys_undone({key}); + return std::vector(missing.begin(), missing.end()); + } + std::set key_set; + for (auto item : keys) { + key_set.insert(py::cast(item)); + } + const auto missing = node.embedding_keys_undone(key_set); + return std::vector(missing.begin(), missing.end()); + }) + .def("do_embedding", [](lazyllm::DocNode& node, const py::object& embed) { + py::dict embed_dict = py::dict(embed); + const auto text = node.get_text(lazyllm::MetadataMode::EMBED); + for (auto item : embed_dict) { + const std::string key = py::cast(item.first); + py::object func = py::reinterpret_borrow(item.second); + py::object result = func(py::str(text)); + node.set_embedding_vec(key, result.cast>()); + } + }, py::arg("embed")) + .def("set_embedding", [](lazyllm::DocNode& node, const std::string& key, const std::vector& value) { + node.set_embedding_vec(key, value); 
+ }) + .def("check_embedding_state", [](lazyllm::DocNode& node, const std::string& key) { + while (true) { + if (node._embedding_vecs.find(key) != node._embedding_vecs.end()) { + node._pending_embedding_keys.erase(key); + break; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + }) + .def("get_content", [](py::object self) { + return self.cast().get_text(lazyllm::MetadataMode::LLM); + }) + .def("get_metadata_str", [](py::object self, const py::object& mode) { + auto& node = self.cast(); + if (mode.is_none()) return node.get_metadata_string(lazyllm::MetadataMode::ALL); + return node.get_metadata_string(pyu::ParseMetadataMode(mode)); + }, py::arg("mode") = py::none()) + .def("get_text", [](py::object self, const py::object& metadata_mode) { + return self.cast().get_text(pyu::ParseMetadataMode(metadata_mode)); + }, py::arg("metadata_mode") = py::none()) + .def("to_dict", [](py::object self) { + auto& node = self.cast(); + py::dict d; + d["content"] = node.get_text(lazyllm::MetadataMode::NONE); + d["embedding"] = node._embedding_vecs; + d["metadata"] = MetadataToPy(node._metadata); + return d; + }) + .def("with_score", [](const lazyllm::DocNode& node, double score) { + lazyllm::DocNode out = node; + out._relevance_score = score; + return out; + }) + .def("with_sim_score", [](const lazyllm::DocNode& node, double score) { + lazyllm::DocNode out = node; + out._similarity_score = score; + return out; + }); +} diff --git a/csrc/binding/export_node_transform.cpp b/csrc/binding/export_node_transform.cpp new file mode 100644 index 000000000..e162c6d27 --- /dev/null +++ b/csrc/binding/export_node_transform.cpp @@ -0,0 +1,398 @@ +#include "lazyllm.hpp" + +#include "doc_node.hpp" +#include "node_transform.hpp" + +#include + +#include +#include + +#include +#include +#include +#include + +namespace { + +class PyNodeTransform : public lazyllm::NodeTransform { +public: + using lazyllm::NodeTransform::NodeTransform; + + std::vector transform(lazyllm::PDocNode node) const override { + py::gil_scoped_acquire gil; + py::function overload = py::get_override(static_cast(this), "transform"); + if (!overload) { + overload = py::get_override(static_cast(this), "forward"); + } + if (!overload) throw std::runtime_error("NodeTransform.transform is not implemented."); + + py::object result = overload(node); + if (!py::isinstance(result)) { + throw std::runtime_error("NodeTransform.transform must return a sequence."); + } + + std::vector out; + for (auto item : result) { + py::object obj = py::reinterpret_borrow(item); + if (obj.is_none()) continue; + out.emplace_back(obj.cast()); + } + return out; + } +}; + +py::object GetBaseModule() { + return py::module_::import("lazyllm.tools.rag.transform.base"); +} + +py::object GetRuleSetClass() { + return GetBaseModule().attr("RuleSet"); +} + +py::object GetContextClass() { + return GetBaseModule().attr("_Context"); +} + +py::object GetDocNodeClass() { + return py::module_::import("lazyllm.tools.rag.doc_node").attr("DocNode"); +} + +py::object GetRichDocNodeClass() { + return py::module_::import("lazyllm.tools.rag.doc_node").attr("RichDocNode"); +} + +bool IsDocNode(const py::object& obj) { + return py::isinstance(obj, GetDocNodeClass()); +} + +bool IsRichDocNode(const py::object& obj) { + return py::isinstance(obj, GetRichDocNodeClass()); +} + +struct TransformRuntimeState { + py::object rules = py::none(); + py::object on_match = py::none(); + py::object on_miss = py::none(); +}; + +std::unordered_map& GetTransformStates() { + // Intentionally leaked to avoid 
py::object teardown during Python finalization. + static auto* states = new std::unordered_map(); + return *states; +} + +TransformRuntimeState& GetTransformState(const py::object& self) { + auto* ptr = self.cast(); + auto& states = GetTransformStates(); + auto it = states.find(ptr); + if (it == states.end()) { + TransformRuntimeState state; + state.rules = GetRuleSetClass()(); + auto inserted = states.emplace(ptr, std::move(state)); + return inserted.first->second; + } + return it->second; +} + +py::dict GetNodeChildrenDict(const py::object& node) { + try { + return py::dict(node.attr("children")); + } catch (const py::error_already_set&) { + try { + return py::dict(node.attr("_children")); + } catch (const py::error_already_set&) { + return py::dict(); + } + } +} + +void SetNodeChildrenDict(const py::object& node, const py::dict& children) { + try { + node.attr("children") = children; + } catch (const py::error_already_set&) { + try { + node.attr("_children") = children; + } catch (const py::error_already_set&) { + } + } +} + +py::list GetRefNodes(const py::object& node, const py::list& ref_path) { + py::list current; + current.append(node); + for (auto key_obj : ref_path) { + const std::string key = py::cast(key_obj); + py::list next; + for (auto n_obj : current) { + py::object n = py::reinterpret_borrow(n_obj); + py::dict children = GetNodeChildrenDict(n); + py::object group_nodes = children.attr("get")(key, py::list()); + for (auto child : group_nodes) next.append(child); + } + current = next; + } + return current; +} + +py::object CallForward(const py::object& self, const py::object& node, const py::dict& kwargs) { + py::object forward = py::getattr(self, "forward", py::none()); + if (forward.is_none()) { + throw std::runtime_error("NodeTransform.forward is not implemented."); + } + if (kwargs.is_none() || kwargs.empty()) { + return forward(node); + } + return forward(node, **kwargs); +} + +void ExtendList(py::list& target, const py::list& src) { + for (auto item : src) target.append(item); +} + +} // namespace + +void exportNodeTransform(py::module& m) { + py::class_(m, "NodeTransform", py::dynamic_attr()) + .def(py::init([](int num_workers, + py::object rules, + bool /*return_trace*/, + py::kwargs /*kwargs*/) { + auto inst = std::make_unique(num_workers); + TransformRuntimeState state; + state.rules = rules.is_none() ? 
GetRuleSetClass()() : rules; + GetTransformStates()[inst.get()] = std::move(state); + return inst; + }), + py::arg("num_workers") = 0, + py::arg("rules") = py::none(), + py::arg("return_trace") = false + ) + .def("batch_forward", + [](py::object self, + py::object documents, + const std::string& node_group, + py::object ref_path, + py::kwargs kwargs) { + py::list docs; + if (py::isinstance(documents) || py::isinstance(documents)) { + for (auto item : documents) docs.append(py::reinterpret_borrow(item)); + } else { + docs.append(documents); + } + + py::list all_outputs; + const bool support_rich = py::bool_(py::getattr(self, "__support_rich__", py::bool_(false))); + + for (auto node_obj : docs) { + py::object node = py::reinterpret_borrow(node_obj); + py::dict children = GetNodeChildrenDict(node); + if (children.contains(py::str(node_group))) { + continue; + } + + py::list splits; + if (!ref_path.is_none()) { + py::list ref_nodes = GetRefNodes(node, py::cast(ref_path)); + if (ref_nodes.empty()) continue; + + py::dict forward_kwargs = py::dict(kwargs); + forward_kwargs["ref"] = ref_nodes; + if (support_rich) { + if (py::len(ref_nodes) == 1) { + splits = py::cast(CallForward(self, ref_nodes[0], forward_kwargs)); + } else { + py::object rich = GetRichDocNodeClass()(py::arg("nodes") = ref_nodes); + splits = py::cast(CallForward(self, rich, forward_kwargs)); + } + } else { + splits = py::list(); + for (auto ref_node_obj : ref_nodes) { + py::object ref_node = py::reinterpret_borrow(ref_node_obj); + py::list out = py::cast(CallForward(self, ref_node, forward_kwargs)); + ExtendList(splits, out); + } + } + } else { + if (IsRichDocNode(node) && !support_rich) { + splits = py::list(); + for (auto sub : node.attr("nodes")) { + py::object sub_node = py::reinterpret_borrow(sub); + py::list out = py::cast(CallForward(self, sub_node, kwargs)); + ExtendList(splits, out); + } + } else { + splits = py::cast(CallForward(self, node, kwargs)); + } + } + + for (auto s_obj : splits) { + py::object s = py::reinterpret_borrow(s_obj); + try { + s.attr("parent") = node; + } catch (const py::error_already_set&) { + } + py::setattr(s, "_group", py::str(node_group)); + } + children[py::str(node_group)] = splits; + SetNodeChildrenDict(node, children); + ExtendList(all_outputs, splits); + } + + return all_outputs; + }, + py::arg("documents"), + py::arg("node_group"), + py::arg("ref_path") = py::none(), + py::return_value_policy::reference + ) + .def("transform", + [](py::object self, py::object node, py::kwargs kwargs) -> py::object { + if (node.is_none()) return py::object(py::list()); + return CallForward(self, node, py::dict(kwargs)); + }, + py::arg("document") + ) + .def("forward", + [](py::object /*self*/, py::object /*node*/, py::kwargs /*kwargs*/) { + PyErr_SetString(PyExc_NotImplementedError, + "Subclasses must implement forward() to process a single DocNode or RichDocNode"); + throw py::error_already_set(); + }, + py::arg("node") + ) + .def("__call__", + [](py::object self, py::object node_or_nodes, py::kwargs kwargs) -> py::object { + py::list results; + py::object forward_single = self.attr("_forward_single"); + if (py::isinstance(node_or_nodes) || py::isinstance(node_or_nodes)) { + for (auto item : node_or_nodes) { + py::object node = py::reinterpret_borrow(item); + if (!IsDocNode(node)) { + throw py::type_error( + "__call__() expects DocNode objects, got non-DocNode in list."); + } + py::list out = py::cast(forward_single(node, **kwargs)); + ExtendList(results, out); + } + return py::object(results); + } + + 
if (!IsDocNode(node_or_nodes)) { + throw py::type_error("__call__() expects DocNode or RichDocNode."); + } + return forward_single(node_or_nodes, **kwargs); + } + ) + .def("_forward_single", + [](py::object self, py::object node, py::kwargs kwargs) -> py::object { + const bool support_rich = py::bool_(py::getattr(self, "__support_rich__", py::bool_(false))); + if (IsRichDocNode(node) && !support_rich) { + py::list out; + for (auto sub : node.attr("nodes")) { + py::object sub_node = py::reinterpret_borrow(sub); + py::list res = py::cast(CallForward(self, sub_node, py::dict(kwargs))); + ExtendList(out, res); + } + return py::object(out); + } + return CallForward(self, node, py::dict(kwargs)); + } + ) + .def("process", + [](py::object self, py::list nodes, py::object on_match, py::object on_miss) { + py::object rules = py::getattr(self, "_rules", py::none()); + if (rules.is_none()) rules = GetRuleSetClass()(); + + py::object match_handler = on_match; + py::object miss_handler = on_miss; + + py::object instance_match = py::getattr(self, "_on_match", py::none()); + py::object instance_miss = py::getattr(self, "_on_miss", py::none()); + if (match_handler.is_none()) { + match_handler = instance_match.is_none() + ? py::getattr(self, "_default_match_handler") + : instance_match; + } + if (miss_handler.is_none()) { + miss_handler = instance_miss.is_none() + ? py::getattr(self, "_default_miss_handler") + : instance_miss; + } + + py::object ctx = GetContextClass()(py::arg("total") = py::len(nodes)); + py::list results; + size_t i = 0; + for (auto node_obj : nodes) { + py::object node = py::reinterpret_borrow(node_obj); + ctx.attr("current_idx") = i++; + py::object match = rules.attr("first")(node); + py::object processed = match.is_none() + ? miss_handler(node, ctx) + : match_handler(node, match, ctx); + results.append(processed); + ctx.attr("prev_node") = node; + ctx.attr("prev_result") = processed; + } + return results; + }, + py::arg("nodes"), + py::arg("on_match") = py::none(), + py::arg("on_miss") = py::none() + ) + .def("_default_match_handler", + [](py::object /*self*/, py::object /*node*/, py::object matched, py::object /*ctx*/) { + py::tuple tup = matched.cast(); + return tup[1]; + } + ) + .def("_default_miss_handler", + [](py::object /*self*/, py::object node, py::object /*ctx*/) { + return node; + } + ) + .def( + "with_name", + [](py::object self, py::object name, bool copy) -> py::object { + if (name.is_none()) return self; + if (copy) { + try { + py::object copier = py::module_::import("copy").attr("copy"); + py::object new_self = copier(self); + new_self.attr("_name") = name; + return new_self; + } catch (const py::error_already_set&) { + // Fallback to in-place mutation if copy fails. + } + } + self.attr("_name") = name; + return self; + }, + py::arg("name"), + py::kw_only(), + py::arg("copy") = true, + py::return_value_policy::reference + ) + .def_readwrite("_name", &lazyllm::NodeTransform::_name) + .def_property("_number_workers", + &lazyllm::NodeTransform::worker_num, + &lazyllm::NodeTransform::set_worker_num + ) + .def_property("_rules", + [](py::object self) { return GetTransformState(self).rules; }, + [](py::object self, py::object value) { + GetTransformState(self).rules = value.is_none() ? 
GetRuleSetClass()() : value; + } + ) + .def_property("_on_match", + [](py::object self) { return GetTransformState(self).on_match; }, + [](py::object self, py::object value) { GetTransformState(self).on_match = value; } + ) + .def_property("_on_miss", + [](py::object self) { return GetTransformState(self).on_miss; }, + [](py::object self, py::object value) { GetTransformState(self).on_miss = value; } + ) + ; + + m.attr("NodeTransform").attr("__support_rich__") = py::bool_(false); +} diff --git a/csrc/binding/export_sentence_splitter.cpp b/csrc/binding/export_sentence_splitter.cpp new file mode 100644 index 000000000..bc3d4efac --- /dev/null +++ b/csrc/binding/export_sentence_splitter.cpp @@ -0,0 +1,224 @@ +#include "lazyllm.hpp" + +#include "sentence_splitter.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +class PySentenceSplitter final : public lazyllm::SentenceSplitter { +public: + using lazyllm::SentenceSplitter::SentenceSplitter; + + std::vector transform(const lazyllm::PDocNode node) const override { + PYBIND11_OVERRIDE( + std::vector, + lazyllm::SentenceSplitter, + transform, + node + ); + } +}; + +py::object GetBaseModule() { + return py::module_::import("lazyllm.tools.rag.transform.base"); +} + +py::object GetUnset() { + return GetBaseModule().attr("_UNSET"); +} + +bool IsUnset(const py::object& value) { + return value.is(GetUnset()); +} + +py::dict GetDefaultParams(const py::object& cls) { + if (!py::hasattr(cls, "_default_params")) { + cls.attr("_default_params") = py::dict(); + } + return cls.attr("_default_params").cast(); +} + +unsigned ResolveUnsigned( + const py::object& cls, + const std::string& param_name, + const py::object& value, + unsigned default_value +) { + if (value.is_none()) return default_value; + if (!IsUnset(value)) { + if (py::isinstance(value)) return value.cast(); + throw py::type_error(param_name + " must be an int"); + } + + py::dict defaults = GetDefaultParams(cls); + if (defaults.contains(param_name.c_str())) { + py::object v = py::reinterpret_borrow(defaults[param_name.c_str()]); + if (py::isinstance(v)) return v.cast(); + } + return default_value; +} + +std::string Trim(const std::string& input) { + size_t start = 0; + size_t end = input.size(); + while (start < end && std::isspace(static_cast(input[start]))) ++start; + while (end > start && std::isspace(static_cast(input[end - 1]))) --end; + return input.substr(start, end - start); +} + +void InitTiktokenTokenizer( + const py::object& py_self, + const std::string& encoding_name, + const py::object& model_name, + py::object allowed_special, + py::object disallowed_special +) { + (void)allowed_special; + (void)disallowed_special; + + std::string tokenizer_name = encoding_name; + if (!model_name.is_none()) { + if (!py::isinstance(model_name)) { + throw py::type_error("model_name must be a string"); + } + tokenizer_name = model_name.cast(); + } + auto tokenizer = std::make_shared(tokenizer_name); + + py::object encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer->encode(text); + }); + py::object decode = py::cpp_function([tokenizer](const std::vector& ids) { + py::bytes raw(tokenizer->decode(ids)); + return raw.attr("decode")("utf-8", "replace"); + }); + + py_self.attr("token_encoder") = encode; + py_self.attr("token_decoder") = decode; +} + +} // namespace + +void exportSentenceSplitter(py::module& m) { + py::class_< + lazyllm::SentenceSplitter, + lazyllm::TextSplitterBase, + PySentenceSplitter 
+ >(m, "SentenceSplitter", py::dynamic_attr()) + .def(py::init([](py::object chunk_size, + py::object chunk_overlap, + py::object num_workers) { + py::object cls = py::module_::import("lazyllm.lazyllm_cpp").attr("SentenceSplitter"); + + unsigned cs = ResolveUnsigned(cls, "chunk_size", chunk_size, 1024); + unsigned ov = ResolveUnsigned(cls, "overlap", chunk_overlap, 200); + unsigned nw = ResolveUnsigned(cls, "num_workers", num_workers, 0); + + if (ov > cs) { + throw py::value_error( + "Got a larger chunk overlap (" + std::to_string(ov) + ") than chunk size (" + + std::to_string(cs) + "), should be smaller."); + } + if (cs == 0) { + throw py::value_error("chunk size should > 0 and overlap should >= 0"); + } + + return std::make_unique(cs, ov, nw, "gpt2"); + }), + py::arg("chunk_size") = py::none(), + py::arg("chunk_overlap") = py::none(), + py::arg("num_workers") = py::none() + ) + .def("_merge", + [](lazyllm::SentenceSplitter& self, py::list splits, int chunk_size) { + if (py::len(splits) == 0) return std::vector{}; + + struct SplitData { std::string text; bool is_sentence; int token_size; }; + std::vector data; + data.reserve(py::len(splits)); + for (auto item : splits) { + py::object s = py::reinterpret_borrow(item); + data.push_back({ + s.attr("text").cast(), + s.attr("is_sentence").cast(), + s.attr("token_size").cast() + }); + } + + std::vector chunks; + std::vector> cur_chunk; + int cur_chunk_len = 0; + bool is_chunk_new = true; + int overlap = self.overlap(); + + auto close_chunk = [&]() { + std::string joined; + joined.reserve(256); + for (const auto& part : cur_chunk) joined += part.first; + chunks.push_back(joined); + + auto last_chunk = cur_chunk; + cur_chunk.clear(); + cur_chunk_len = 0; + is_chunk_new = true; + + int overlap_len = 0; + for (auto it = last_chunk.rbegin(); it != last_chunk.rend(); ++it) { + if (overlap_len + it->second > overlap) break; + cur_chunk.push_back(*it); + overlap_len += it->second; + cur_chunk_len += it->second; + } + std::reverse(cur_chunk.begin(), cur_chunk.end()); + }; + + size_t i = 0; + while (i < data.size()) { + const auto& cur_split = data[i]; + if (cur_split.token_size > chunk_size) { + throw py::value_error("Single token exceeded chunk size"); + } + if (cur_chunk_len + cur_split.token_size > chunk_size && !is_chunk_new) { + close_chunk(); + } else { + if (cur_split.is_sentence + || cur_chunk_len + cur_split.token_size <= chunk_size + || is_chunk_new) { + cur_chunk_len += cur_split.token_size; + cur_chunk.push_back({cur_split.text, cur_split.token_size}); + ++i; + is_chunk_new = false; + } else { + close_chunk(); + } + } + } + + if (!is_chunk_new) { + std::string joined; + joined.reserve(256); + for (const auto& part : cur_chunk) joined += part.first; + chunks.push_back(joined); + } + + std::vector out; + for (const auto& chunk : chunks) { + std::string trimmed = Trim(chunk); + if (!trimmed.empty()) out.push_back(trimmed); + } + return out; + }, + py::arg("splits"), py::arg("chunk_size") + ); +} diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp new file mode 100644 index 000000000..9305f163a --- /dev/null +++ b/csrc/binding/export_text_splitter_base.cpp @@ -0,0 +1,537 @@ +#include "lazyllm.hpp" + +#include "text_splitter_base.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +using SplitFn = lazyllm::TextSplitterBase::SplitFn; + +class PyTextSplitterBase final : public lazyllm::TextSplitterBase { +public: + using 
lazyllm::TextSplitterBase::TextSplitterBase; + + std::vector transform(const lazyllm::PDocNode node) const override { + PYBIND11_OVERRIDE( + std::vector, + lazyllm::TextSplitterBase, + transform, + node + ); + } +}; + +py::object GetBaseModule() { + return py::module_::import("lazyllm.tools.rag.transform.base"); +} + +py::object GetUnset() { + return GetBaseModule().attr("_UNSET"); +} + +bool IsUnset(const py::object& value) { + return value.is(GetUnset()); +} + +py::dict GetDefaultParams(const py::object& cls) { + if (!py::hasattr(cls, "_default_params")) { + cls.attr("_default_params") = py::dict(); + } + return cls.attr("_default_params").cast(); +} + +unsigned ResolveUnsigned( + const py::object& cls, + const std::string& param_name, + const py::object& value, + unsigned default_value +) { + if (value.is_none()) return default_value; + if (!IsUnset(value)) { + if (py::isinstance(value)) return value.cast(); + throw py::type_error(param_name + " must be an int"); + } + + py::dict defaults = GetDefaultParams(cls); + if (defaults.contains(param_name.c_str())) { + py::object v = py::reinterpret_borrow(defaults[param_name.c_str()]); + if (py::isinstance(v)) return v.cast(); + } + return default_value; +} + +py::object GetSplitClass() { + return GetBaseModule().attr("_Split"); +} + +py::object MakeSplit(const std::string& text, bool is_sentence, int token_size) { + return GetSplitClass()(py::arg("text") = text, + py::arg("is_sentence") = is_sentence, + py::arg("token_size") = token_size); +} + +std::pair, bool> GetSplitsByFns(const std::string& text) { + py::object base = GetBaseModule(); + py::object split_keep = base.attr("split_text_keep_separator"); + py::object nltk = base.attr("nltk"); + py::object tokenizer = nltk.attr("tokenize").attr("PunktSentenceTokenizer")(); + + py::object splits_obj = split_keep(text, "\n\n\n"); + auto splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, true}; + + splits_obj = tokenizer.attr("tokenize")(text); + splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, true}; + + py::object re = base.attr("re"); + splits_obj = re.attr("findall")(py::str(R"([^,.;。?!]+[,.;。?!]?)"), text); + splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, false}; + + splits_obj = split_keep(text, " "); + splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, false}; + + splits_obj = py::list(py::str(text)); + splits = splits_obj.cast>(); + return {splits, false}; +} + +void InitTiktokenTokenizer( + const py::object& py_self, + const std::string& encoding_name, + const py::object& model_name, + py::object allowed_special, + py::object disallowed_special +); + +void EnsureTokenCodec(const py::object& self) { + py::dict obj_dict = py::getattr(self, "__dict__", py::dict()).cast(); + const bool has_encoder = + obj_dict.contains("token_encoder") && + !py::reinterpret_borrow(obj_dict["token_encoder"]).is_none(); + const bool has_decoder = + obj_dict.contains("token_decoder") && + !py::reinterpret_borrow(obj_dict["token_decoder"]).is_none(); + if (!has_encoder || !has_decoder) { + InitTiktokenTokenizer(self, "gpt2", py::none(), py::none(), py::str("all")); + } +} + +int TokenSize(const py::object& self, const std::string& text) { + EnsureTokenCodec(self); + py::object encoder = self.attr("token_encoder"); + py::object tokens = encoder(text); + return static_cast(py::len(tokens)); +} + +void ExtendList(py::list& target, const py::list& src) { + for (auto item : src) target.append(item); +} + +py::list SplitRecursive(const 
py::object& self, const std::string& text, int chunk_size) { + const int token_size = TokenSize(self, text); + if (token_size <= chunk_size) { + py::list out; + out.append(MakeSplit(text, true, token_size)); + return out; + } + + auto [text_splits, is_sentence] = GetSplitsByFns(text); + py::list results; + for (const auto& segment : text_splits) { + const int seg_token_size = TokenSize(self, segment); + if (seg_token_size <= chunk_size) { + results.append(MakeSplit(segment, is_sentence, seg_token_size)); + } else { + py::list sub = SplitRecursive(self, segment, chunk_size); + ExtendList(results, sub); + } + } + return results; +} + +void InitTiktokenTokenizer( + const py::object& py_self, + const std::string& encoding_name, + const py::object& model_name, + py::object allowed_special, + py::object disallowed_special +) { + (void)allowed_special; + (void)disallowed_special; + + std::string tokenizer_name = encoding_name; + if (!model_name.is_none()) { + if (!py::isinstance(model_name)) { + throw py::type_error("model_name must be a string"); + } + tokenizer_name = model_name.cast(); + } + auto tokenizer = std::make_shared(tokenizer_name); + + py::object encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer->encode(text); + }); + py::object decode = py::cpp_function([tokenizer](const std::vector& ids) { + py::bytes raw(tokenizer->decode(ids)); + return raw.attr("decode")("utf-8", "replace"); + }); + + py_self.attr("token_encoder") = encode; + py_self.attr("token_decoder") = decode; +} + +void InitTokenizerFromObject(const py::object& py_self, const py::object& tokenizer, bool huggingface) { + py::object encode; + py::object decode; + if (huggingface) { + encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer.attr("encode")(text, py::arg("add_special_tokens") = false); + }); + decode = py::cpp_function([tokenizer](const std::vector& ids) { + return tokenizer.attr("decode")(ids, py::arg("skip_special_tokens") = true); + }); + } else { + encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer.attr("encode")(text); + }); + decode = py::cpp_function([tokenizer](const std::vector& ids) { + return tokenizer.attr("decode")(ids); + }); + } + + py_self.attr("token_encoder") = encode; + py_self.attr("token_decoder") = decode; +} + +} // namespace + +void exportTextSpliterBase(py::module& m) { + py::class_( + m, "_TextSplitterBase", py::dynamic_attr()) + .def(py::init([](py::object chunk_size, + py::object overlap, + py::object num_workers) { + py::object cls = py::module_::import("lazyllm.lazyllm_cpp").attr("_TextSplitterBase"); + + unsigned cs = ResolveUnsigned(cls, "chunk_size", chunk_size, 1024); + unsigned ov = ResolveUnsigned(cls, "overlap", overlap, 200); + unsigned nw = ResolveUnsigned(cls, "num_workers", num_workers, 0); + + if (ov > cs) { + throw py::value_error( + "Got a larger chunk overlap (" + std::to_string(ov) + ") than chunk size (" + + std::to_string(cs) + "), should be smaller."); + } + if (cs == 0) { + throw py::value_error("chunk size should > 0 and overlap should >= 0"); + } + + return std::make_unique(cs, ov, nw, "gpt2"); + }), + py::arg("chunk_size") = py::none(), + py::arg("overlap") = py::none(), + py::arg("num_workers") = py::none() + ) + .def_property_readonly("_chunk_size", &lazyllm::TextSplitterBase::chunk_size) + .def_property_readonly("_overlap", &lazyllm::TextSplitterBase::overlap) + .def("__getattr__", + [](py::object self, const std::string& name) -> py::object { + if (name == 
"token_encoder" || name == "token_decoder") { + EnsureTokenCodec(self); + py::dict obj_dict = py::getattr(self, "__dict__", py::dict()).cast(); + if (obj_dict.contains(name.c_str())) { + return py::reinterpret_borrow(obj_dict[name.c_str()]); + } + } + throw py::attribute_error( + "'" + std::string(py::str(py::type::of(self).attr("__name__"))) + + "' object has no attribute '" + name + "'"); + }, + py::arg("name") + ) + .def("split_text", + [](lazyllm::TextSplitterBase& self, const std::string& text, int metadata_size) { + if (text.empty()) return std::vector{""}; + + const int chunk_size = self.chunk_size(); + const int effective_chunk_size = chunk_size - metadata_size; + if (effective_chunk_size <= 0) { + throw py::value_error( + "Metadata length (" + std::to_string(metadata_size) + ") is longer than chunk size (" + + std::to_string(chunk_size) + "). Consider increasing the chunk size or decreasing the size of " + "your metadata to avoid this."); + } else if (effective_chunk_size < 50) { + try { + py::object log = py::module_::import("lazyllm").attr("LOG"); + log.attr("warning")( + "Metadata length (" + std::to_string(metadata_size) + ") is close to chunk size (" + + std::to_string(chunk_size) + "). Resulting chunks are less than 50 tokens. " + "Consider increasing the chunk size or decreasing the size of your metadata to avoid this."); + } catch (const py::error_already_set&) { + } + } + + py::object py_self = py::cast(&self, py::return_value_policy::reference); + py::object splits = py_self.attr("_split")(text, effective_chunk_size); + py::object chunks = py_self.attr("_merge")(splits, effective_chunk_size); + return chunks.cast>(); + }, + py::arg("text"), py::arg("metadata_size") + ) + .def_static("split_text_keep_separator", + [](const std::string& text, const std::string& separator) { + auto views = lazyllm::TextSplitterBase::split_text_while_keeping_separator(text, separator); + std::vector out; + out.reserve(views.size()); + for (auto view : views) out.emplace_back(view); + return out; + }, + py::arg("text"), py::arg("separator") + ) + .def("from_tiktoken_encoder", + [](py::object self, + const std::string& encoding_name, + py::object model_name, + py::object allowed_special, + py::object disallowed_special, + py::kwargs /*kwargs*/) { + InitTiktokenTokenizer(self, encoding_name, model_name, allowed_special, disallowed_special); + return self; + }, + py::arg("encoding_name") = "gpt2", + py::arg("model_name") = py::none(), + py::arg("allowed_special") = py::none(), + py::arg("disallowed_special") = "all" + ) + .def("from_tiktoken_encoding", + [](py::object self, const std::string& encoding_name) { + InitTiktokenTokenizer(self, encoding_name, py::none(), py::none(), py::str("all")); + return self; + }, + py::arg("encoding_name") = "gpt2" + ) + .def("from_tokenizer", + [](py::object self, py::object tokenizer) { + InitTokenizerFromObject(self, tokenizer, false); + return self; + }, + py::arg("tokenizer") + ) + .def("from_huggingface_tokenizer", + [](py::object self, py::object tokenizer) { + InitTokenizerFromObject(self, tokenizer, true); + return self; + }, + py::arg("tokenizer") + ) + .def("set_split_fns", + [](py::object /*self*/, const py::object& /*split_fns*/, py::object /*sub_split_fns*/) { + return; + }, + py::arg("split_fns"), py::arg("sub_split_fns") = py::none() + ) + .def("add_split_fn", + [](py::object /*self*/, const py::object& /*split_fn*/, py::object /*index*/) { + return; + }, + py::arg("split_fn"), py::arg("index") = py::none() + ) + .def("clear_split_fns", + 
[](py::object /*self*/) { + return; + } + ) + .def("_token_size", + [](py::object self, const std::string& text) { + return TokenSize(self, text); + }, + py::arg("text") + ) + .def("_get_metadata_size", + [](py::object self, py::object node) { + py::object meta_mode = GetBaseModule().attr("MetadataMode"); + std::string embed = node.attr("get_metadata_str")(meta_mode.attr("EMBED")).cast(); + std::string llm = node.attr("get_metadata_str")(meta_mode.attr("LLM")).cast(); + return std::max(TokenSize(self, embed), TokenSize(self, llm)); + }, + py::arg("node") + ) + .def("_get_splits_by_fns", + [](py::object /*self*/, const std::string& text) { + auto [splits, is_sentence] = GetSplitsByFns(text); + return py::make_tuple(splits, is_sentence); + }, + py::arg("text") + ) + .def("_split", + [](py::object self, const std::string& text, int chunk_size) { + return SplitRecursive(self, text, chunk_size); + }, + py::arg("text"), py::arg("chunk_size") + ) + .def("_merge", + [](lazyllm::TextSplitterBase& self, py::list splits, int chunk_size) { + if (py::len(splits) == 0) return std::vector{}; + if (py::len(splits) == 1) { + py::object s = splits[0]; + return std::vector{py::cast(s.attr("text"))}; + } + + struct SplitData { std::string text; bool is_sentence; int token_size; }; + std::vector data; + data.reserve(py::len(splits)); + for (auto item : splits) { + py::object s = py::reinterpret_borrow(item); + data.push_back({ + s.attr("text").cast(), + s.attr("is_sentence").cast(), + s.attr("token_size").cast() + }); + } + + const int overlap = self.overlap(); + py::object py_self = py::cast(&self, py::return_value_policy::reference); + EnsureTokenCodec(py_self); + py::object encoder = py_self.attr("token_encoder"); + py::object decoder = py_self.attr("token_decoder"); + + if (data.back().token_size == chunk_size && overlap > 0) { + SplitData end_split = data.back(); + data.pop_back(); + auto tokens = encoder(end_split.text).cast>(); + const size_t half = tokens.size() / 2; + auto mid = static_cast::difference_type>(half); + std::vector prefix(tokens.begin(), tokens.begin() + mid); + std::vector suffix(tokens.begin() + mid, tokens.end()); + std::string p_text = decoder(prefix).cast(); + std::string n_text = decoder(suffix).cast(); + data.push_back({p_text, end_split.is_sentence, TokenSize(py_self, p_text)}); + data.push_back({n_text, end_split.is_sentence, TokenSize(py_self, n_text)}); + } + + SplitData end_split = data.back(); + std::vector result; + for (int idx = static_cast(data.size()) - 2; idx >= 0; --idx) { + SplitData start_split = data[static_cast(idx)]; + if (start_split.token_size <= overlap && end_split.token_size <= chunk_size - overlap) { + end_split.text = start_split.text + end_split.text; + end_split.is_sentence = start_split.is_sentence && end_split.is_sentence; + end_split.token_size += start_split.token_size; + continue; + } + + if (end_split.token_size > chunk_size) { + throw py::value_error( + "split token size (" + std::to_string(end_split.token_size) + ") is greater than chunk size (" + + std::to_string(chunk_size) + ")."); + } + + const int remaining_space = chunk_size - end_split.token_size; + const int overlap_len = std::min({overlap, remaining_space, start_split.token_size}); + if (overlap_len > 0) { + auto start_tokens = encoder(start_split.text).cast>(); + std::vector overlap_tokens(start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = decoder(overlap_tokens).cast(); + end_split.text = overlap_text + end_split.text; + end_split.token_size += 
overlap_len; + } + + result.insert(result.begin(), end_split.text); + end_split = start_split; + } + + result.insert(result.begin(), end_split.text); + return result; + }, + py::arg("splits"), py::arg("chunk_size") + ) + .def("forward", + [](py::object self, py::object node, py::kwargs /*kwargs*/) { + const std::string text = node.attr("get_text")().cast(); + const int metadata_size = self.attr("_get_metadata_size")(node).cast(); + py::object chunks = self.attr("split_text")(text, metadata_size); + + std::vector out; + py::object doc_cls = py::module_::import("lazyllm.tools.rag.doc_node").attr("DocNode"); + for (auto item : chunks) { + py::object chunk = py::reinterpret_borrow(item); + if (chunk.is_none()) continue; + if (py::isinstance(chunk, doc_cls)) { + out.push_back(chunk.cast()); + continue; + } + std::string chunk_text = chunk.cast(); + if (chunk_text.empty()) continue; + out.push_back(std::make_shared(std::move(chunk_text))); + } + return out; + }, + py::arg("node") + ) + .def("_get_param_value", + [](py::object self, const std::string& param_name, py::object value, py::object default_value) { + if (!IsUnset(value)) return value; + py::object cls = py::type::of(self); + py::dict defaults = GetDefaultParams(cls); + if (defaults.contains(param_name.c_str())) { + return py::reinterpret_borrow(defaults[param_name.c_str()]); + } + return default_value; + }, + py::arg("param_name"), py::arg("value"), py::arg("default") + ); + + py::object py_cls = m.attr("_TextSplitterBase"); + py::object builtins = py::module_::import("builtins"); + + py_cls.attr("_get_class_lock") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass) { + if (!py::hasattr(klass, "_default_params_lock")) { + klass.attr("_default_params_lock") = py::module_::import("threading").attr("RLock")(); + } + return klass.attr("_default_params_lock"); + }) + ); + + py_cls.attr("set_default") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass, py::kwargs kwargs) { + py::dict defaults = GetDefaultParams(klass); + for (auto item : kwargs) { + defaults[item.first] = item.second; + } + klass.attr("_default_params") = defaults; + }) + ); + + py_cls.attr("get_default") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass, py::object param_name) -> py::object { + py::dict defaults = GetDefaultParams(klass); + if (param_name.is_none()) return py::object(defaults); + const std::string key = param_name.cast(); + if (defaults.contains(key.c_str())) { + return py::reinterpret_borrow(defaults[key.c_str()]); + } + return py::none(); + }) + ); + + py_cls.attr("reset_default") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass) { + klass.attr("_default_params") = py::dict(); + }) + ); +} diff --git a/csrc/binding/lazyllm.cpp b/csrc/binding/lazyllm.cpp index 2789e8e1c..db701edda 100644 --- a/csrc/binding/lazyllm.cpp +++ b/csrc/binding/lazyllm.cpp @@ -1,21 +1,23 @@ #include "lazyllm.hpp" -#include "doc_node.h" +#include "document_store.hpp" +#include "doc_node.hpp" + +#include namespace py = pybind11; PYBIND11_MODULE(lazyllm_cpp, m) { m.doc() = "LazyLLM CPP Module."; - exportDoc(m); + exportAddDocStr(m); - // prevent document generation + // Prevent document generation py::options options; options.disable_function_signatures(); - // DocNode - py::class_(m, "DocNode") - .def(py::init<>()) - .def(py::init(), py::arg("text")) - .def("set_text", &lazyllm::DocNode::set_text, py::arg("text")) - .def("get_text", &lazyllm::DocNode::get_text); + // Export classes + exportDocNode(m); 
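+ // Note: the export order below matters -- pybind11 requires a base class + // (_TextSplitterBase) to be bound before any derived class (SentenceSplitter) + // that lists it as a base.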
+ exportNodeTransform(m); + exportTextSpliterBase(m); + exportSentenceSplitter(m); } diff --git a/csrc/binding/lazyllm.hpp b/csrc/binding/lazyllm.hpp index d248f6677..62a7564f7 100644 --- a/csrc/binding/lazyllm.hpp +++ b/csrc/binding/lazyllm.hpp @@ -1,6 +1,16 @@ #pragma once +#include +#include + +#include #include #include -void exportDoc(pybind11::module& m); +namespace py = pybind11; + +void exportAddDocStr(pybind11::module& m); +void exportDocNode(pybind11::module& m); +void exportNodeTransform(pybind11::module& m); +void exportTextSpliterBase(pybind11::module& m); +void exportSentenceSplitter(pybind11::module& m); diff --git a/csrc/cmake/tests.cmake b/csrc/cmake/tests.cmake index 3eb114d9e..ce18c6b4a 100644 --- a/csrc/cmake/tests.cmake +++ b/csrc/cmake/tests.cmake @@ -1,8 +1,6 @@ -include(FetchContent) - FetchContent_Declare( googletest - URL https://codeload.github.com/google/googletest/zip/refs/tags/release-1.12.1 + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip ) # Fix gtest version to maintain C++11 compatibility. set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) @@ -29,5 +27,15 @@ foreach (test_src ${LAZYLLM_TEST_SOURCES}) pybind11::headers Python3::Python ) - gtest_discover_tests(${test_name}) + gtest_add_tests( + TARGET ${test_name} + TEST_LIST discovered_tests + ) + + # Attach runtime env per discovered case so each test gets the same loader path. + if (LAZYLLM_TEST_RUNTIME_ENV AND discovered_tests) + set_tests_properties(${discovered_tests} PROPERTIES + ENVIRONMENT "${LAZYLLM_TEST_RUNTIME_ENV}" + ) + endif () endforeach () diff --git a/csrc/cmake/third_party.cmake b/csrc/cmake/third_party.cmake new file mode 100644 index 000000000..794f2d51e --- /dev/null +++ b/csrc/cmake/third_party.cmake @@ -0,0 +1,40 @@ +include(FetchContent) + +find_package(Python3 COMPONENTS Interpreter Development Development.Module REQUIRED) +find_package(pybind11 CONFIG REQUIRED) + +find_package(xxHash QUIET) +if (NOT TARGET xxhash) + FetchContent_Declare( + xxhash + GIT_REPOSITORY https://github.com/Cyan4973/xxHash.git + GIT_TAG v0.8.2 + ) + FetchContent_Populate(xxhash) + add_subdirectory(${xxhash_SOURCE_DIR}/cmake_unofficial ${xxhash_BINARY_DIR}) +endif() + +find_package(cpp_tiktoken QUIET) +if (NOT TARGET cpp_tiktoken) + # We only need cpp_tiktoken for in-tree usage; avoid exporting/installing it. + set(CPP_TIKTOKEN_INSTALL OFF CACHE BOOL "" FORCE) + set(CPP_TIKTOKEN_TESTING OFF CACHE BOOL "" FORCE) + FetchContent_Declare( + cpp_tiktoken + GIT_REPOSITORY https://github.com/gh-markt/cpp-tiktoken.git + GIT_TAG master + ) + FetchContent_MakeAvailable(cpp_tiktoken) +endif() + +find_package(utf8proc QUIET) +if (NOT TARGET utf8proc) + # We only need utf8proc for in-tree usage; avoid exporting/installing it. 
+ set(UTF8PROC_INSTALL OFF CACHE BOOL "" FORCE) + FetchContent_Declare( + utf8proc + GIT_REPOSITORY https://github.com/JuliaStrings/utf8proc.git + GIT_TAG v2.9.0 + ) + FetchContent_MakeAvailable(utf8proc) +endif() diff --git a/csrc/core/include/adaptor_base.hpp b/csrc/core/include/adaptor_base.hpp new file mode 100644 index 000000000..602dcb2d3 --- /dev/null +++ b/csrc/core/include/adaptor_base.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include + +#if defined(__GNUC__) || defined(__clang__) +#define LAZYLLM_HIDDEN __attribute__((visibility("hidden"))) +#else +#define LAZYLLM_HIDDEN +#endif + +namespace lazyllm { + +class AdaptorBase { +public: + virtual ~AdaptorBase() = default; + virtual std::any call( + const std::string& func_name, + const std::unordered_map& args) const = 0; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp new file mode 100644 index 000000000..aa54c398c --- /dev/null +++ b/csrc/core/include/doc_node.hpp @@ -0,0 +1,216 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" +#include "adaptor_base.hpp" + +namespace lazyllm { + +enum class MetadataMode { ALL, EMBED, LLM, NONE }; + +class DocNode; +using PDocNode = std::shared_ptr; + +class DocNode { +public: + using MetadataVType = lazyllm::MetadataVType; + using Metadata = std::unordered_map; + using Children = std::unordered_map>; + using EmbeddingFun = std::function(const std::string&, const std::string&)>; + using EmbeddingVecs = std::unordered_map>; + + Metadata _metadata; + std::shared_ptr _p_global_metadata; + EmbeddingVecs _embedding_vecs; + std::set _pending_embedding_keys = {}; + double _relevance_score = .0; + double _similarity_score = .0; + std::string _group_name; + +private: + std::string_view _text_view; + std::shared_ptr _p_root_text = nullptr; + std::vector _root_texts = {}; + std::string _uid; + mutable size_t _text_hash = 0; + + std::set _excluded_embed_metadata_keys; + std::set _excluded_llm_metadata_keys; + + const DocNode* _p_parent_node = nullptr; + mutable Children _children; + std::shared_ptr _p_store = nullptr; + +public: + explicit DocNode( + const std::string_view& text_view, + const std::string& group_name = "", + const std::string& uid = "", + const DocNode* p_parent_node = nullptr, + const Metadata& metadata = {}, + const std::shared_ptr& global_metadata = {} + ) : + _group_name(group_name), + _uid(uid.empty() ? GenerateUUID() : uid), + _p_parent_node(p_parent_node), + _metadata(metadata), + _p_global_metadata(global_metadata) + { + set_text_view(text_view); + } + DocNode() : DocNode("", "") {} + explicit DocNode(std::string&& text, const std::shared_ptr& global_metadata = {}) + : DocNode(text, "", "", nullptr, {}, global_metadata) { + set_root_text(std::move(text)); + } + explicit DocNode(const char* text, const std::shared_ptr& global_metadata = {}) + : DocNode(std::string(text == nullptr ? 
"" : text), global_metadata) {} + + DocNode(const DocNode&) = default; + DocNode& operator=(const DocNode&) = default; + virtual ~DocNode() = default; + + size_t evaluate_text_hash() const { + return static_cast(XXH64(_text_view.data(), _text_view.size(), 0)); + } + bool is_children_group_exists(const std::string& group_name) const { + if (_children.empty()) return false; + return _children.find(group_name) != _children.end(); + } + + // Getter and Setter + const DocNode* get_root_node() const { + if (_p_parent_node == nullptr) return this; + return _p_parent_node->get_root_node(); + } + void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } + const std::string& get_uid() const { return _uid; } + void set_uid(const std::string& uid) { _uid = uid; } + const std::string_view& get_text_view() const { return _text_view; } + void set_text_view(const std::string_view& text_view) { + _text_view = text_view; + _text_hash = evaluate_text_hash(); + } + std::string get_metadata_string(MetadataMode mode = MetadataMode::ALL) const { + if (mode == MetadataMode::NONE) return ""; + + std::set valid_keys; + for (const auto& [key, _] : _metadata) valid_keys.insert(key); + + if (mode == MetadataMode::LLM) + valid_keys = SetDiff(valid_keys, _excluded_llm_metadata_keys); + else if (mode == MetadataMode::EMBED) + valid_keys = SetDiff(valid_keys, _excluded_embed_metadata_keys); + + std::vector kv_strings; + for (const std::string& key : valid_keys) + kv_strings.emplace_back(key + ": " + any_to_string(_metadata.at(key))); + + return JoinLines(kv_strings); + } + std::string get_text(MetadataMode mode = MetadataMode::NONE) const { + if (mode == MetadataMode::NONE) return std::string(_text_view); + const auto& metadata_string = get_metadata_string(mode); + if (metadata_string.empty()) return std::string(_text_view); + return metadata_string + "\n\n" + std::string(_text_view); + } + void set_root_text(const std::string& text) { + _p_root_text = std::make_shared(text); + set_text_view(*_p_root_text); + } + void set_root_text(std::string&& text) { + _p_root_text = std::make_shared(std::move(text)); + set_text_view(*_p_root_text); + } + void set_root_texts(const std::vector& texts) { set_root_text(JoinLines(texts)); } + size_t get_text_hash() const { return _text_hash; } + const DocNode* get_parent_node() const { return _p_parent_node; } + void set_parent_node(const DocNode* p_parent_node) { _p_parent_node = p_parent_node; } + Children get_children() const { + if (!_children.empty()) return _children; + if (_p_store == nullptr) return Children(); + _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); + return _children; + } + Children& get_children_ref() { + if (_children.empty() && _p_store != nullptr) { + _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); + } + return _children; + } + void set_children(const Children& children) { _children = children; } + void set_children_group( + const std::string& group_name, + const std::vector& children_group) { + _children[group_name] = children_group; + } + std::set get_excluded_embed_metadata_keys() const { + if (_p_parent_node == nullptr) return _excluded_embed_metadata_keys; + return SetUnion(_p_parent_node->get_excluded_embed_metadata_keys(), _excluded_embed_metadata_keys); + } + void set_excluded_embed_metadata_keys(const std::set& keys) { + _excluded_embed_metadata_keys = keys; + } + std::set get_excluded_llm_metadata_keys() const { + if (_p_parent_node == nullptr) return 
_excluded_llm_metadata_keys; + return SetUnion(_p_parent_node->get_excluded_llm_metadata_keys(), _excluded_llm_metadata_keys); + } + void set_excluded_llm_metadata_keys(const std::set<std::string>& keys) { + _excluded_llm_metadata_keys = keys; + } + std::string get_doc_path() const { + const auto& value = get_root_node()->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_PATH)); + return std::get<std::string>(value); + } + void set_doc_path(const std::string& path) { + get_root_node()->_p_global_metadata->operator[](std::string(RAGMetadataKeys::DOC_PATH)) = path; + } + auto get_children_uid() const { + auto children = get_children(); + std::unordered_map<std::string, std::vector<std::string>> children_uid; + for (auto& [group_name, nodes] : children) { + children_uid[group_name] = {}; + for (auto& node : nodes) + children_uid[group_name].push_back(node->get_uid()); + } + return children_uid; + } + std::string get_parent_uid() const { + auto parent = get_parent_node(); + if (parent == nullptr) return ""; + return parent->get_uid(); + } + std::set<std::string> embedding_keys_undone(const std::set<std::string>& keys_done) const { + if (keys_done.empty()) throw std::runtime_error("The embed_keys to be checked must be passed in."); + std::set<std::string> keys_undone; + for (const auto& key : keys_done) { + if (_embedding_vecs.find(key) == _embedding_vecs.end()) + keys_undone.insert(key); + } + return keys_undone; + } + void py_do_embedding(const std::unordered_map<std::string, std::function<std::vector<float>(const std::string&)>>& embedding_funcs) { + for (const auto& [key, func] : embedding_funcs) + _embedding_vecs[key] = func(get_text(MetadataMode::EMBED)); + } + void set_embedding_vec(const std::string& key, const std::vector<float>& embedding_vec) { + _embedding_vecs[key] = embedding_vec; + } + + bool operator==(const DocNode& other) const { return _uid == other._uid; } + bool operator!=(const DocNode& other) const { return _uid != other._uid; } +}; + +} // namespace lazyllm diff --git a/csrc/core/include/map_params.hpp b/csrc/core/include/map_params.hpp new file mode 100644 index 000000000..7b75cebba --- /dev/null +++ b/csrc/core/include/map_params.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include <any> +#include <mutex> +#include <optional> +#include <string> +#include <string_view> +#include <unordered_map> + +namespace lazyllm { + +class MapParams { +public: + using MapType = std::unordered_map<std::string, std::any>; + + template <typename T> + T get_param_value( + const std::string_view& param_name, + const std::optional<T>& value, + const T& default_value) const + { + if (value.has_value()) return *value; + std::lock_guard<std::mutex> guard(_lock); + auto it = _params.find(std::string(param_name)); + if (it != _params.end()) return std::any_cast<T>(it->second); + + // Don't set default value here to avoid parameter pollution. 
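+ // Illustrative usage (a hedged sketch, not part of this header's API): + // lazyllm::MapParams params; + // params.set_default("chunk_size", 2048); + // params.get_param_value<int>("chunk_size", std::nullopt, 1024); // -> 2048 (stored default wins) + // params.get_param_value<int>("overlap", std::nullopt, 200); // -> 200 (nothing stored, falls through)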
+ return default_value; + } + + template + void set_default(const std::string_view& param_name, T value) { + std::lock_guard guard(_lock); + _params[std::string(param_name)] = std::any(value); + } + + void set_default(const MapType& updates) { + std::lock_guard guard(_lock); + for (const auto& entry : updates) { + _params[entry.first] = entry.second; + } + } + + MapType get_default() const { + std::lock_guard guard(_lock); + return _params; + } + + template + std::optional get_default(const std::string& param_name) const { + std::lock_guard guard(_lock); + auto it = _params.find(param_name); + if (it == _params.end()) return std::nullopt; + return std::any_cast(it->second); + } + + void reset_default() { + std::lock_guard guard(_lock); + _params.clear(); + } + +private: + mutable std::mutex _lock; + MapType _params; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/node_transform.hpp b/csrc/core/include/node_transform.hpp new file mode 100644 index 000000000..ca6a8f195 --- /dev/null +++ b/csrc/core/include/node_transform.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "doc_node.hpp" +#include "thread_pool.hpp" + +namespace lazyllm { + +class NodeTransform { +public: + std::string _name = ""; + + explicit NodeTransform(int worker_num = 0) : _worker_num(worker_num) {} + virtual ~NodeTransform() = default; + + virtual std::vector transform(const PDocNode) const = 0; + std::vector operator()(const PDocNode node) const { return transform(node); } + std::vector batch_forward(std::vector& nodes, const std::string& node_group_name); + + int worker_num() const { return _worker_num; } + void set_worker_num(int worker_num) { _worker_num = worker_num; } + +private: + std::vector forward(PDocNode node, const std::string& node_group_name); + + std::mutex _lock; + int _worker_num = 0; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/sentence_splitter.hpp b/csrc/core/include/sentence_splitter.hpp new file mode 100644 index 000000000..8963cdd9b --- /dev/null +++ b/csrc/core/include/sentence_splitter.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "text_splitter_base.hpp" + +namespace lazyllm { + +class SentenceSplitter : public TextSplitterBase { +public: + explicit SentenceSplitter( + std::optional chunk_size, + std::optional chunk_overlap, + std::optional worker_num, + const std::string& encoding_name = "gpt2") + : TextSplitterBase(chunk_size, chunk_overlap, worker_num, encoding_name) {} + +protected: + std::vector merge_chunks(const std::vector& splits, int chunk_size) const override; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp new file mode 100644 index 000000000..20e30f86c --- /dev/null +++ b/csrc/core/include/text_splitter_base.hpp @@ -0,0 +1,101 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "doc_node.hpp" +#include "map_params.hpp" +#include "node_transform.hpp" +#include "tokenizer.hpp" + +namespace lazyllm { + +class TextSplitterBase : public NodeTransform { +public: + using SplitFn = std::function(const std::string&)>; + static MapParams _default_params; + + explicit TextSplitterBase( + std::optional chunk_size = std::nullopt, + std::optional overlap = std::nullopt, + std::optional worker_num = std::nullopt, + const 
std::string& encoding_name = "gpt2") + : NodeTransform(_default_params.get_param_value("worker_num", worker_num, 0)), + _chunk_size(_default_params.get_param_value("chunk_size", chunk_size, 1024)), + _overlap(_default_params.get_param_value("overlap", overlap, 200)) + { + if (_overlap >= _chunk_size) throw std::runtime_error("'overlap' should be less than 'chunk_size'."); + if (_chunk_size == 0) throw std::runtime_error("'chunk_size' should > 0"); + + _tokenizer = std::make_shared(encoding_name); + } + + std::vector transform(PDocNode node) const override { + if (node == nullptr) return {}; + auto chunks = split_text(node->get_text_view(), get_node_metadata_size(*node)); + std::vector nodes; + nodes.reserve(chunks.size()); + for (auto& chunk : chunks) + nodes.push_back(std::make_shared(std::move(chunk))); + return nodes; + } + + std::vector split_text(const std::string_view& view, int metadata_size) const; + static std::vector split_text_while_keeping_separator( + const std::string_view& text, + const std::string_view& separator); + + TextSplitterBase& from_tiktoken_encoder( + const std::string& encoding_name = "gpt2", + const std::optional& model_name = std::nullopt) + { + _tokenizer = std::make_shared(model_name.value_or(encoding_name)); + return *this; + } + + void set_tokenizer(std::shared_ptr tokenizer) { _tokenizer = std::move(tokenizer); } + int chunk_size() const { return _chunk_size; } + int overlap() const { return _overlap; } + + virtual void set_split_functions( + const std::vector&, + const std::optional>& = std::nullopt) {} + virtual void add_split_function(const SplitFn&, const std::optional& = std::nullopt) {} + virtual void clear_split_functions() {} + +protected: + virtual std::vector split_recursive(const std::string_view& view, const int chunk_size) const; + virtual std::vector merge_chunks(const std::vector& splits, int chunk_size) const; + +private: + std::tuple, bool> split_by_functions(const std::string_view& text) const; + + int get_node_metadata_size(const DocNode& node) const { + return std::max( + get_token_size(node.get_metadata_string(MetadataMode::EMBED)), + get_token_size(node.get_metadata_string(MetadataMode::LLM))); + } + + int get_token_size(const std::string_view& view) const { + if (view.empty()) return 0; + return static_cast(_tokenizer->encode(view).size()); + } + +protected: + int _chunk_size; + int _overlap; + std::shared_ptr _tokenizer; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/thread_pool.hpp b/csrc/core/include/thread_pool.hpp new file mode 100644 index 000000000..ca9c2a9c8 --- /dev/null +++ b/csrc/core/include/thread_pool.hpp @@ -0,0 +1,99 @@ +// https://github.com/progschj/ThreadPool + +#ifndef THREAD_POOL_H +#define THREAD_POOL_H +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... 
args) + -> std::future::type>; + ~ThreadPool(); +private: + // need to keep track of threads so we can join them + std::vector< std::thread > workers; + // the task queue + std::queue< std::function > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) + : stop(false) +{ + for(size_t i = 0;i task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this]{ return this->stop || !this->tasks.empty(); }); + if(this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + } + ); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> +{ + using return_type = typename std::result_of::type; + + auto task = std::make_shared< std::packaged_task >( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if(stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task](){ (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() +{ + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for(std::thread &worker: workers) + worker.join(); +} + +#endif diff --git a/csrc/core/include/tokenizer.hpp b/csrc/core/include/tokenizer.hpp new file mode 100644 index 000000000..a7f2aada8 --- /dev/null +++ b/csrc/core/include/tokenizer.hpp @@ -0,0 +1,157 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +class Tokenizer { +public: + virtual ~Tokenizer() = default; + virtual std::vector encode(const std::string_view& view) const = 0; + virtual std::string decode(const std::vector& token_ids) const = 0; +}; + +class FallbackByteTokenizer final : public Tokenizer { +public: + FallbackByteTokenizer() = default; + ~FallbackByteTokenizer() override = default; + + std::vector encode(const std::string_view& view) const override { + std::vector token_ids; + token_ids.reserve(view.size()); + for (unsigned char ch : view) { + token_ids.push_back(static_cast(ch)); + } + return token_ids; + } + + std::string decode(const std::vector& token_ids) const override { + std::string text; + text.reserve(token_ids.size()); + for (int id : token_ids) { + text.push_back(static_cast(id & 0xFF)); + } + return text; + } +}; + +class TiktokenTokenizer final : public Tokenizer { +public: + TiktokenTokenizer() = delete; + explicit TiktokenTokenizer(LanguageModel model) + : _encoding(load_encoding(model)) {} + + explicit TiktokenTokenizer(std::string_view encoding_name) + : TiktokenTokenizer(parse_tiktoken_model(encoding_name)) {} + + ~TiktokenTokenizer() override = default; + + std::vector encode(const std::string_view& view) const override { + return _encoding->encode(std::string(view)); // TODO refactor to string_view + } + + std::string decode(const std::vector& token_ids) const override { + return _encoding->decode(token_ids); + } + +private: + class FilePathResourceReader final : public IResourceReader { + public: + explicit FilePathResourceReader(std::filesystem::path resource_path) + : resource_path_(std::move(resource_path)) {} + + 
std::vector readLines() override { + std::ifstream file(resource_path_); + if (!file.is_open()) { + throw std::runtime_error("Embedded resource '" + resource_path_.string() + "' not found."); + } + std::string line; + std::vector lines; + while (std::getline(file, line)) lines.push_back(line); + return lines; + } + + private: + std::filesystem::path resource_path_; + }; + + static std::string resource_name(LanguageModel model) { + switch (model) { + case LanguageModel::R50K_BASE: return "r50k_base.tiktoken"; + case LanguageModel::P50K_BASE: return "p50k_base.tiktoken"; + case LanguageModel::P50K_EDIT: return "p50k_base.tiktoken"; + case LanguageModel::CL100K_BASE: return "cl100k_base.tiktoken"; + case LanguageModel::O200K_BASE: return "o200k_base.tiktoken"; + case LanguageModel::QWEN_BASE: return "qwen.tiktoken"; + } + throw std::runtime_error("Unknown language model"); + } + + static std::shared_ptr load_encoding(LanguageModel model) { + try { + return GptEncoding::get_encoding(model); + } catch (const std::exception&) { + const std::filesystem::path repo_root = + std::filesystem::path(__FILE__).parent_path().parent_path().parent_path().parent_path(); + const std::string file_name = resource_name(model); + const std::vector candidates = { + repo_root / "build" / "tokenizers" / file_name, + repo_root / "tokenizers" / file_name, + std::filesystem::current_path() / "build" / "tokenizers" / file_name, + std::filesystem::current_path() / "tokenizers" / file_name + }; + for (const auto& path : candidates) { + if (!std::filesystem::exists(path)) continue; + FilePathResourceReader reader(path); + return GptEncoding::get_encoding(model, &reader); + } + throw; + } + } + + static bool has_prefix(std::string_view value, std::string_view prefix) { + return value.size() >= prefix.size() && value.compare(0, prefix.size(), prefix) == 0; + } + + static LanguageModel parse_tiktoken_model(std::string_view name) { + if (name.empty()) return LanguageModel::R50K_BASE; + + // Model-name aliases used by Python tiktoken.encoding_for_model. + if (name == "gpt-3.5-turbo" || has_prefix(name, "gpt-3.5-turbo-")) return LanguageModel::CL100K_BASE; + if (name == "gpt-4" || has_prefix(name, "gpt-4-")) return LanguageModel::CL100K_BASE; + if (name == "text-embedding-ada-002") return LanguageModel::CL100K_BASE; + if (name == "text-embedding-3-small" || name == "text-embedding-3-large") return LanguageModel::CL100K_BASE; + if (name == "gpt-4o" || has_prefix(name, "gpt-4o-")) return LanguageModel::O200K_BASE; + if (name == "gpt-4.1" || has_prefix(name, "gpt-4.1-")) return LanguageModel::O200K_BASE; + if (name == "gpt-4.5" || has_prefix(name, "gpt-4.5-")) return LanguageModel::O200K_BASE; + if (name == "o1" || has_prefix(name, "o1-")) return LanguageModel::O200K_BASE; + if (name == "o3" || has_prefix(name, "o3-")) return LanguageModel::O200K_BASE; + if (name == "o4-mini" || has_prefix(name, "o4-mini-")) return LanguageModel::O200K_BASE; + + if (name == "gpt2" || name == "r50k_base" || name == "r50k") return LanguageModel::R50K_BASE; + if (name == "p50k_base" || name == "p50k") return LanguageModel::P50K_BASE; + if (name == "p50k_edit") return LanguageModel::P50K_EDIT; + if (name == "cl100k_base" || name == "cl100k") return LanguageModel::CL100K_BASE; + if (name == "o200k_base" || name == "o200k") return LanguageModel::O200K_BASE; + if (name == "qwen_base" || name == "qwen") return LanguageModel::QWEN_BASE; + + throw std::runtime_error( + "Unknown tiktoken encoding/model name: " + std::string(name) + + ". 
Expected one of: gpt2, r50k_base, p50k_base, p50k_edit, cl100k_base, o200k_base, qwen_base." + + "(Case sensitive)"); + } + +private: + std::shared_ptr _encoding; +}; diff --git a/csrc/core/include/unicode_processor.hpp b/csrc/core/include/unicode_processor.hpp new file mode 100644 index 000000000..05ef9c268 --- /dev/null +++ b/csrc/core/include/unicode_processor.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace lazyllm { + +class UnicodeProcessor { +public: + UnicodeProcessor(const std::string_view& text) : _text(text), _text_len(text.size()) {} + std::vector split_to_chars() const; + std::vector split_by_punctuation() const; + +private: + using Utf8Visitor = std::function; + + void for_each_utf8_unit(const Utf8Visitor& visitor) const; + static bool is_sentence_punctuation(char32_t codepoint) { + return std::find(kPunctuationCodepoints.begin(), kPunctuationCodepoints.end(), + codepoint) != kPunctuationCodepoints.end(); + } + + static const std::array kPunctuationCodepoints; + std::string_view _text; + size_t _text_len = 0; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp new file mode 100644 index 000000000..bf9bb1f19 --- /dev/null +++ b/csrc/core/include/utils.hpp @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace lazyllm { + +using MetadataVType = std::variant< + std::string, std::vector, + int, std::vector, + double, std::vector, + std::any +>; + +struct RAGMetadataKeys { + inline static constexpr std::string_view KB_ID = "kb_id"; + inline static constexpr std::string_view DOC_ID = "docid"; + inline static constexpr std::string_view DOC_PATH = "lazyllm_doc_path"; + inline static constexpr std::string_view DOC_FILE_NAME = "file_name"; + inline static constexpr std::string_view DOC_FILE_TYPE = "file_type"; + inline static constexpr std::string_view DOC_FILE_SIZE = "file_size"; + inline static constexpr std::string_view DOC_CREATION_DATE = "creation_date"; + inline static constexpr std::string_view DOC_LAST_MODIFIED_DATE = "last_modified_date"; + inline static constexpr std::string_view DOC_LAST_ACCESSED_DATE = "last_accessed_date"; +}; + +inline std::string JoinLines(const std::vector& lines, char delim = '\n') { + if (lines.empty()) return {}; + std::string out = lines.front(); + for (size_t i = 1; i < lines.size(); ++i) { + out += delim; + out += lines[i]; + } + return out; +} + +template +std::vector ConcatVector(const std::vector& l, const std::vector& r) { + std::vector out; + out.reserve(l.size() + r.size()); + out.insert(out.end(), l.begin(), l.end()); + out.insert(out.end(), r.begin(), r.end()); + return out; +} + +template +std::set SetUnion(const std::set& l, const std::set& r) { + std::set out; + std::set_union(l.begin(), l.end(), r.begin(), r.end(), std::inserter(out, out.begin())); + return out; +} + +template +std::set SetDiff(const std::set& l, const std::set& r) { + std::set out; + std::set_difference(l.begin(), l.end(), r.begin(), r.end(), std::inserter(out, out.begin())); + return out; +} + +inline std::string to_hex(size_t v) { + std::ostringstream oss; + oss << std::hex << v; + return oss.str(); +} + +inline std::string GenerateUUID() { + static const char HEX_CHAR[] = "0123456789abcdef"; + static const int SEGS[] = {8, 4, 4, 4, 12}; + + // Single static generator per thread. 
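+ // std::mt19937 seeded from std::random_device is not cryptographically secure; + // it only needs to make node UIDs unique enough in practice. The output shape is + // 8-4-4-4-12 lowercase hex, e.g. "1f3a9c2e-4b7d-0a51-9e8c-2d6f4b1a7c3e" (an + // illustrative value; note no RFC 4122 version/variant bits are set).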
+ static thread_local std::mt19937 GEN(std::random_device{}()); + static thread_local std::uniform_int_distribution DIST(0, 15); + + std::string out; + out.reserve(36); + for (int segLength : SEGS) { + for (int i = 0; i < segLength; ++i) + out.push_back(HEX_CHAR[DIST(GEN)]); + if (segLength < 12) + out.push_back('-'); + } + return out; +} + +std::string any_to_string(const MetadataVType& value); + +inline bool is_adjacent(const std::string_view& left, const std::string_view& right) { + return left.data() + left.size() == right.data(); +} + +struct ChunkView { + std::string_view view; + bool is_sentence = false; + int token_size = 0; +}; + +struct Chunk { + std::string text; + bool is_sentence = false; + int token_size = 0; + + Chunk& operator+=(const Chunk& r) { + text += r.text; + is_sentence = is_sentence && r.is_sentence; + token_size += r.token_size; + return *this; + } +}; +} // namespace lazyllm diff --git a/csrc/core/src/doc_node.cpp b/csrc/core/src/doc_node.cpp new file mode 100644 index 000000000..031d608c5 --- /dev/null +++ b/csrc/core/src/doc_node.cpp @@ -0,0 +1 @@ +#include "doc_node.hpp" diff --git a/csrc/core/src/node_transform.cpp b/csrc/core/src/node_transform.cpp new file mode 100644 index 000000000..8d69eea9b --- /dev/null +++ b/csrc/core/src/node_transform.cpp @@ -0,0 +1,45 @@ +#include "node_transform.hpp" + +namespace lazyllm { + +std::vector NodeTransform::batch_forward( + std::vector& nodes, const std::string& node_group_name +) { + std::vector whole_nodes; + if (nodes.empty()) return whole_nodes; + + if (_worker_num > 0) { + ThreadPool pool(static_cast(_worker_num)); + std::vector>> futures; + futures.reserve(nodes.size()); + for (auto node : nodes) { + futures.emplace_back(pool.enqueue( + [this, node, node_group_name]() { return forward(node, node_group_name); })); + } + + for (auto& fut : futures) { + auto parts = fut.get(); + whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); + } + } else { + for (auto node : nodes) { + auto parts = forward(node, node_group_name); + whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); + } + } + + return whole_nodes; +} + +std::vector NodeTransform::forward(PDocNode node, const std::string& node_group_name) { + auto p_nodes = transform(node); + for (auto& p_node : p_nodes) { + p_node->set_parent_node(&*node); + p_node->_group_name = node_group_name; + } + node->set_children_group(node_group_name, p_nodes); + + return p_nodes; +} + +} // namespace lazyllm diff --git a/csrc/core/src/sentence_splitter.cpp b/csrc/core/src/sentence_splitter.cpp new file mode 100644 index 000000000..3b9931903 --- /dev/null +++ b/csrc/core/src/sentence_splitter.cpp @@ -0,0 +1,59 @@ +#include "sentence_splitter.hpp" + +#include +#include + +namespace lazyllm { + +std::string join_views( + size_t string_size, + std::vector::const_iterator begin, + const std::vector::const_iterator& end +) { + std::string out; + out.reserve(string_size); + while(begin != end) { + out.append(begin->view); + ++begin; + } + return out; +} + +std::vector SentenceSplitter::merge_chunks(const std::vector& chunks, int chunk_size) const { + std::vector out; + + auto iLeft = chunks.begin(); + auto iRight = chunks.begin(); + const auto& iEnd = chunks.end(); + int window_token_sum = 0; + size_t string_size = 0; + + while (iRight != iEnd) { + if (iRight->token_size > chunk_size) + throw std::runtime_error("Chunk size is too big."); + + // Grow right edge to the largest window under chunk_size. 
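+ // Worked trace (hedged, illustrative numbers): chunk_size = 10, _overlap = 3, + // split token sizes [4, 4, 2, 4]: + // grow -> window [4, 4, 2] (sum 10), emit splits 0..2 as one chunk; + // shrink -> keep [2] as overlap (2 <= 3); + // grow -> window [2, 4] (sum 6), emit splits 2..3 as the next chunk.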
+ while (iRight != iEnd && window_token_sum + iRight->token_size <= chunk_size) { + window_token_sum += iRight->token_size; + string_size += iRight->view.size(); + ++iRight; + } + + // Merge chunks within window. + out.push_back(join_views(string_size, iLeft, iRight)); + + // Shrink left edge to select overlap of next merge. + while (iRight != iEnd && iLeft != iRight && ( + window_token_sum > _overlap || window_token_sum + iRight->token_size > chunk_size + )) { + window_token_sum -= iLeft->token_size; + string_size -= iLeft->view.size(); + ++iLeft; + } + // Now window contains only overlap. + } + + return out; +} + +} // namespace lazyllm diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp new file mode 100644 index 000000000..1ae75877d --- /dev/null +++ b/csrc/core/src/text_splitter_base.cpp @@ -0,0 +1,194 @@ +#include "text_splitter_base.hpp" +#include "unicode_processor.hpp" + +namespace lazyllm { + +MapParams TextSplitterBase::_default_params{}; + +/* + * split_text + * ---------- + * Purpose: + * 1) Validate the chunk budget after accounting for metadata tokens. + * 2) Recursively split the original text view into token-bounded ChunkView pieces. + * 3) Merge the pieces into final chunk strings with overlap behavior aligned with the Python implementation. + * + * Flow: + * 1) Compute effective_chunk_size = chunk_size - metadata_size. + * 2) Reject invalid/too-small budgets. + * 3) Call split_recursive(...) to produce the ChunkView sequence. + * 4) Call merge_chunks(...) to build the final std::string chunks. + * + * Notes: + * - This function intentionally returns std::string chunks because the current tokenizer + * encode/decode materializes strings in the merge path. + * - Ownership is explicit here to avoid dangling string_views in downstream DocNode construction. + * + * TODO: + * - Once the tokenizer supports true string_view encode/decode, migrate this path back to + * std::vector<std::string_view> and remove the eager string materialization. + */ +std::vector<std::string> TextSplitterBase::split_text(const std::string_view& view, int metadata_size) const { + if (view.empty()) return {}; + int effective_chunk_size = _chunk_size - metadata_size; + if (effective_chunk_size <= 0) { + throw std::runtime_error( + "Metadata length (" + std::to_string(metadata_size) + + ") is longer than chunk size (" + std::to_string(_chunk_size) + + "). Consider increasing the chunk size or decreasing the size of your metadata to avoid this."); + } + else if (effective_chunk_size < 50) { + throw std::runtime_error( + "Metadata length (" + std::to_string(metadata_size) + ") is close to chunk size (" + + std::to_string(_chunk_size) + "). Resulting chunks are less than 50 tokens. 
" + + "Consider increasing the chunk size or decreasing the size of " + + "your metadata to avoid this."); + } + auto splits = split_recursive(view, effective_chunk_size); + return merge_chunks(splits, effective_chunk_size); +} + +std::vector TextSplitterBase::split_recursive( + const std::string_view& view, const int chunk_size) const +{ + int token_size = get_token_size(view); + if (token_size <= chunk_size) return {ChunkView{view, true, token_size}}; + + auto [views, is_sentence] = split_by_functions(view); + std::vector splits; + for (const auto& segment_view : views) { + const int seg_token_size = get_token_size(segment_view); + if (seg_token_size == 0) continue; + if (seg_token_size <= chunk_size) { + splits.push_back({segment_view, is_sentence, seg_token_size}); + } else { + auto new_splits = split_recursive(segment_view, chunk_size); + splits.insert(splits.end(), new_splits.begin(), new_splits.end()); + } + } + return splits; +} + +std::tuple, bool> TextSplitterBase::split_by_functions(const std::string_view& text) const +{ + auto views = split_text_while_keeping_separator(text, "\n\n\n"); + if (views.size() > 1) return {views, true}; + + views = UnicodeProcessor(text).split_by_punctuation(); + if (views.size() > 1) return {views, false}; + + views = split_text_while_keeping_separator(text, " "); + if (views.size() > 1) return {views, false}; + + return {UnicodeProcessor(text).split_to_chars(), false}; +} + +std::vector TextSplitterBase::split_text_while_keeping_separator( + const std::string_view& text, + const std::string_view& separator) +{ + if (text.empty()) return {}; + else if (separator.empty()) return {text}; + + std::vector result; + size_t start = 0; + const size_t sep_len = separator.size(); + while (start < text.size()) { + const size_t idx = text.find(separator, start); + if (idx == std::string_view::npos) { + result.emplace_back(text.substr(start)); + break; + } + + if (idx == start) { + start += sep_len; + continue; + } + + result.emplace_back(text.substr(start, idx + sep_len - start)); + start = idx + sep_len; + } + return result; +} + +/** + * @brief Build final chunks from token-sized split units while preserving overlap semantics. + * + * @details + * 1) Convert input SplitUnit views to owned strings (MergedSplit) for safe concatenation. + * 2) If the tail split exactly matches chunk_size and overlap > 0: + * split it by token-halves via encode/decode, then push both halves back. + * 3) Iterate backward: + * Add previous split, or part of it, to current split as overlap. + * - If the previous split is small enough, prepend it fully. + * - Otherwise, prepend token-based overlap suffix from previous split. + * 4) Emit chunks in original order. + * + * @todo Replace eager string materialization once tokenizer encode/decode supports + * end-to-end zero-copy string_view operations. 
+ */ +std::vector TextSplitterBase::merge_chunks(const std::vector& splits, int chunk_size) const { + if (splits.empty()) return {}; + + std::vector merged_splits; + merged_splits.reserve(splits.size() + 2); + for (const auto& split : splits) + merged_splits.push_back(Chunk{std::string(split.view), split.is_sentence, split.token_size}); + + if (merged_splits.size() == 1) return {merged_splits.front().text}; + + if (merged_splits.back().token_size == chunk_size && _overlap > 0) { + Chunk end_split = merged_splits.back(); + merged_splits.pop_back(); + + auto text_tokens = _tokenizer->encode(end_split.text); + const size_t half = text_tokens.size() / 2; + const auto split_it = text_tokens.begin() + static_cast::difference_type>(half); + std::vector prefix_tokens(text_tokens.begin(), split_it); + std::vector suffix_tokens(split_it, text_tokens.end()); + + std::string prefix_text = _tokenizer->decode(prefix_tokens); + std::string suffix_text = _tokenizer->decode(suffix_tokens); + merged_splits.push_back( + Chunk{prefix_text, end_split.is_sentence, get_token_size(prefix_text)}); + merged_splits.push_back( + Chunk{suffix_text, end_split.is_sentence, get_token_size(suffix_text)}); + } + + Chunk end_split = merged_splits.back(); + std::vector reversed_result; + reversed_result.reserve(merged_splits.size()); + for (int idx = static_cast(merged_splits.size()) - 2; idx >= 0; --idx) { + const Chunk& start_split = merged_splits[static_cast(idx)]; + if (start_split.token_size <= _overlap && end_split.token_size <= chunk_size - _overlap) { + end_split += start_split; + continue; + } + + if (end_split.token_size > chunk_size) { + throw std::runtime_error("split token size is greater than chunk size."); + } + + const int remaining_space = chunk_size - end_split.token_size; + const int overlap_len = std::min({_overlap, remaining_space, start_split.token_size}); + if (overlap_len > 0) { + auto start_tokens = _tokenizer->encode(start_split.text); + std::vector overlap_tokens(start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = _tokenizer->decode(overlap_tokens); + + end_split = Chunk{ + overlap_text + end_split.text, + end_split.is_sentence, + end_split.token_size + overlap_len}; + } + + reversed_result.emplace_back(end_split.text); + end_split = start_split; + } + + reversed_result.emplace_back(end_split.text); + std::reverse(reversed_result.begin(), reversed_result.end()); + return reversed_result; +} + +} diff --git a/csrc/core/src/unicode_processor.cpp b/csrc/core/src/unicode_processor.cpp new file mode 100644 index 000000000..1abdb8c1c --- /dev/null +++ b/csrc/core/src/unicode_processor.cpp @@ -0,0 +1,102 @@ +#include "unicode_processor.hpp" +namespace lazyllm { + +const std::array UnicodeProcessor::kPunctuationCodepoints = { + U',', + U'.', + U';', + U'!', + U'\uFF0C', // , + U'\uFF1B', // ; + U'\u3002', // 。 + U'\uFF1F', // ? + U'\uFF01', // ! +}; + +void UnicodeProcessor::for_each_utf8_unit(const Utf8Visitor& visitor) const { + size_t i = 0; + while (i < _text_len) { + utf8proc_int32_t codepoint = -1; + const utf8proc_ssize_t n = utf8proc_iterate( + reinterpret_cast(_text.data() + i), + static_cast(_text_len - i), + &codepoint); + + if (n <= 0) { + i += 1; + continue; + } + + visitor(i, static_cast(n), codepoint); + i += static_cast(n); + } +} + +/** + * UTF-8 text processing has three distinct layers: + * 1) Byte: the storage unit in std::string_view; one code point uses 1-4 UTF-8 bytes. 
+ */
+std::vector<std::string_view> UnicodeProcessor::split_to_chars() const {
+    std::vector<std::string_view> out;
+    if (_text.empty()) return out;
+    out.reserve(_text_len);  // Grapheme count <= byte length
+
+    size_t cluster_start = std::string_view::npos;
+    utf8proc_int32_t prev = -1;
+    utf8proc_int32_t state = 0;
+
+    for_each_utf8_unit([&](size_t offset, size_t byte_len, utf8proc_int32_t codepoint) {
+        if (cluster_start == std::string_view::npos) {
+            cluster_start = offset;
+        } else if (utf8proc_grapheme_break_stateful(prev, codepoint, &state)) {
+            out.emplace_back(_text.substr(cluster_start, offset - cluster_start));
+            cluster_start = offset;
+        }
+        prev = codepoint;
+    });
+
+    if (cluster_start != std::string_view::npos) {
+        out.emplace_back(_text.substr(cluster_start));
+    }
+    return out;
+}
+
+std::vector<std::string_view> UnicodeProcessor::split_by_punctuation() const {
+    if (_text.empty()) return {};
+
+    std::vector<std::string_view> out;
+    size_t chunk_start = std::string_view::npos;
+
+    for_each_utf8_unit([&](size_t offset, size_t byte_len, char32_t codepoint) {
+        if (is_sentence_punctuation(codepoint)) {
+            if (chunk_start != std::string_view::npos) {
+                const size_t end = offset + byte_len;
+                out.push_back(_text.substr(chunk_start, end - chunk_start));
+                chunk_start = std::string_view::npos;
+            }
+        } else if (chunk_start == std::string_view::npos) {
+            chunk_start = offset;
+        }
+    });
+
+    if (chunk_start != std::string_view::npos) {
+        out.emplace_back(_text.substr(chunk_start));
+    }
+    if (out.empty()) out.emplace_back(_text);
+    return out;
+}
+
+}  // namespace lazyllm
diff --git a/csrc/core/src/utils.cpp b/csrc/core/src/utils.cpp
new file mode 100644
index 000000000..620fdb066
--- /dev/null
+++ b/csrc/core/src/utils.cpp
@@ -0,0 +1,83 @@
+#include "utils.hpp"
+
+#include <sstream>
+
+namespace lazyllm {
+
+namespace {
+
+std::string ScalarToString(const std::string& value) {
+    return value;
+}
+
+std::string ScalarToString(int value) {
+    return std::to_string(value);
+}
+
+std::string ScalarToString(double value) {
+    std::ostringstream oss;
+    oss << value;
+    return oss.str();
+}
+
+template <typename T>
+std::string VectorToString(const std::vector<T>& values) {
+    std::string out = "[";
+    for (size_t i = 0; i < values.size(); ++i) {
+        out += ScalarToString(values[i]);
+        if (i + 1 < values.size()) out += ",";
+    }
+    out += "]";
+    return out;
+}
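+
+// Example: AnyToString(std::any(std::vector<int>{1, 2, 3})) yields "[1,2,3]";
+// the same bracket-and-comma format is asserted for metadata vectors in
+// csrc/tests/test_utils.cpp.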
"1" : "0"; + if (t == typeid(int)) return std::to_string(std::any_cast(value)); + if (t == typeid(long)) return std::to_string(std::any_cast(value)); + if (t == typeid(long long)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned int)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned long)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned long long)) return std::to_string(std::any_cast(value)); + if (t == typeid(float)) { + std::ostringstream oss; + oss << std::any_cast(value); + return oss.str(); + } + if (t == typeid(double)) { + std::ostringstream oss; + oss << std::any_cast(value); + return oss.str(); + } + if (t == typeid(std::vector)) + return VectorToString(std::any_cast&>(value)); + if (t == typeid(std::vector)) + return VectorToString(std::any_cast&>(value)); + if (t == typeid(std::vector)) + return VectorToString(std::any_cast&>(value)); + return ""; +} + +} // namespace + +std::string any_to_string(const MetadataVType& value) { + return std::visit([](const auto& v) -> std::string { + using T = std::decay_t; + if constexpr (std::is_same_v) return ScalarToString(v); + if constexpr (std::is_same_v) return ScalarToString(v); + if constexpr (std::is_same_v) return ScalarToString(v); + if constexpr (std::is_same_v>) return VectorToString(v); + if constexpr (std::is_same_v>) return VectorToString(v); + if constexpr (std::is_same_v>) return VectorToString(v); + if constexpr (std::is_same_v) return AnyToString(v); + return std::string(""); + }, value); +} + +} // namespace lazyllm diff --git a/csrc/include/doc_node.h b/csrc/include/doc_node.h deleted file mode 100644 index 4908d9941..000000000 --- a/csrc/include/doc_node.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -namespace lazyllm { - -class DocNode { -public: - DocNode() = default; - explicit DocNode(const std::string& text); - - void set_text(const std::string& text); - const std::string& get_text() const; - -private: - std::string _text; -}; - -} // namespace lazyllm diff --git a/csrc/scripts/build_debug.sh b/csrc/scripts/build_debug.sh index 08adee4c7..1bea62ee6 100644 --- a/csrc/scripts/build_debug.sh +++ b/csrc/scripts/build_debug.sh @@ -1,7 +1,11 @@ #!/usr/bin/env bash +# Run at LazyLLM/. set -euo pipefail cmake -S csrc -B build \ -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ -DCMAKE_BUILD_TYPE=Debug cmake --build build + +# Install into ./lazyllm (prefix=. + LIBRARY DESTINATION lazyllm). +cmake --install build --prefix . --component lazyllm_cpp diff --git a/csrc/scripts/build_release.sh b/csrc/scripts/build_release.sh index 865d22dc0..91660620a 100644 --- a/csrc/scripts/build_release.sh +++ b/csrc/scripts/build_release.sh @@ -1,7 +1,11 @@ #!/usr/bin/env bash +# Run at LazyLLM/. set -euo pipefail cmake -S csrc -B build-release \ -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ -DCMAKE_BUILD_TYPE=Release cmake --build build-release + +# Install into ./lazyllm (prefix=. + LIBRARY DESTINATION lazyllm). +cmake --install build --prefix . 
diff --git a/csrc/scripts/build_test.sh b/csrc/scripts/build_test.sh
index cd9336823..c6b14f9c3 100644
--- a/csrc/scripts/build_test.sh
+++ b/csrc/scripts/build_test.sh
@@ -6,4 +6,4 @@ cmake -S csrc -B build \
     -DCMAKE_BUILD_TYPE=Debug \
     -DBUILD_TESTS=ON
 cmake --build build
-ctest --test-dir build
+ctest --test-dir build --rerun-failed --output-on-failure
diff --git a/csrc/scripts/config_cmake.sh b/csrc/scripts/config_cmake.sh
new file mode 100644
index 000000000..a02b81209
--- /dev/null
+++ b/csrc/scripts/config_cmake.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+cmake -S csrc -B build \
+    -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \
+    -DCMAKE_BUILD_TYPE=Debug
diff --git a/csrc/src/README.md b/csrc/src/README.md
deleted file mode 100644
index e69de29bb..000000000
diff --git a/csrc/src/doc_node.cpp b/csrc/src/doc_node.cpp
deleted file mode 100644
index f774ec0e5..000000000
--- a/csrc/src/doc_node.cpp
+++ /dev/null
@@ -1,15 +0,0 @@
-#include "doc_node.h"
-
-namespace lazyllm {
-
-DocNode::DocNode(const std::string& text) : _text(text) {}
-
-void DocNode::set_text(const std::string& text) {
-    _text = text;
-}
-
-const std::string& DocNode::get_text() const {
-    return _text;
-}
-
-} // namespace lazyllm
diff --git a/csrc/tests/test_adaptor_base.cpp b/csrc/tests/test_adaptor_base.cpp
new file mode 100644
index 000000000..3a7cc10e8
--- /dev/null
+++ b/csrc/tests/test_adaptor_base.cpp
@@ -0,0 +1,37 @@
+#include <gtest/gtest.h>
+
+#include <any>
+#include <string>
+#include <unordered_map>
+
+#include "adaptor_base.hpp"
+
+namespace {
+
+class MockAdaptor final : public lazyllm::AdaptorBase {
+public:
+    mutable int call_count = 0;
+
+    std::any call(
+        const std::string& func_name,
+        const std::unordered_map<std::string, std::any>& args) const override
+    {
+        ++call_count;
+        if (func_name == "sum") {
+            return std::any_cast<int>(args.at("left")) + std::any_cast<int>(args.at("right"));
+        }
+        return func_name;
+    }
+};
+
+} // namespace
+
+TEST(adaptor_base, derived_call) {
+    MockAdaptor adaptor;
+    auto result = adaptor.call("echo_me", {});
+    EXPECT_EQ(std::any_cast<std::string>(result), "echo_me");
+    EXPECT_EQ(adaptor.call_count, 1);
+
+    result = adaptor.call("sum", {{"left", 3}, {"right", 4}});
+    EXPECT_EQ(std::any_cast<int>(result), 7);
+}
diff --git a/csrc/tests/test_doc_node.cpp b/csrc/tests/test_doc_node.cpp
index ff3dead26..cc8a27d5a 100644
--- a/csrc/tests/test_doc_node.cpp
+++ b/csrc/tests/test_doc_node.cpp
@@ -1,16 +1,160 @@
 #include <gtest/gtest.h>
-#include "doc_node.h"
+#include <memory>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <utility>
 
-TEST(DocNode, DefaultEmpty) {
-    lazyllm::DocNode node;
-    EXPECT_EQ(node.get_text(), "");
-}
+#include "adaptor_base.hpp"
+#include "doc_node.hpp"
+#include "utils.hpp"
+
+namespace {
+
+class MockChildrenAdaptor final : public lazyllm::AdaptorBase {
+public:
+    mutable int call_count = 0;
+    lazyllm::DocNode::Children to_return;
+
+    std::any call(
+        const std::string& func_name,
+        const std::unordered_map<std::string, std::any>& args) const override
+    {
+        ++call_count;
+        EXPECT_EQ(func_name, "get_node_children");
+        EXPECT_TRUE(args.find("node") != args.end());
+        return to_return;
+    }
+};
+
+} // namespace
 
-TEST(DocNode, SetGet) {
-    lazyllm::DocNode node("hello");
+TEST(doc_node, constructor) {
+    lazyllm::DocNode node("hello", "group", "fixed-uid");
+    EXPECT_EQ(node.get_uid(), "fixed-uid");
+    EXPECT_EQ(node._group_name, "group");
     EXPECT_EQ(node.get_text(), "hello");
+}
 
-    node.set_text("world");
+TEST(doc_node, set_text_view_updates_text_hash) {
+    lazyllm::DocNode node("hello", "group", "fixed-uid");
+    const size_t old_hash = node.get_text_hash();
+
node.set_text_view("world"); + EXPECT_NE(node.get_text_hash(), old_hash); EXPECT_EQ(node.get_text(), "world"); } + +TEST(doc_node, metadata) { + lazyllm::DocNode node; + node._metadata = lazyllm::DocNode::Metadata{ + {"alpha", std::string("A")}, + {"beta", std::string("B")}, + }; + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::ALL), "alpha: A\nbeta: B"); + + node.set_excluded_embed_metadata_keys({"beta"}); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::EMBED), "alpha: A"); + + node.set_excluded_llm_metadata_keys({"alpha"}); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::LLM), "beta: B"); + + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::NONE), ""); + + lazyllm::DocNode root("root"); + lazyllm::DocNode child("child", "", "", &root); + + root.set_excluded_embed_metadata_keys({"root_embed"}); + child.set_excluded_embed_metadata_keys({"child_embed"}); + EXPECT_EQ( + child.get_excluded_embed_metadata_keys(), + (std::set{"child_embed", "root_embed"})); + + root.set_excluded_llm_metadata_keys({"root_llm"}); + child.set_excluded_llm_metadata_keys({"child_llm"}); + EXPECT_EQ( + child.get_excluded_llm_metadata_keys(), + (std::set{"child_llm", "root_llm"})); +} + +TEST(doc_node, text) { + lazyllm::DocNode node("body"); + node._metadata = lazyllm::DocNode::Metadata{{"alpha", std::string("A")}}; + EXPECT_EQ(node.get_text(lazyllm::MetadataMode::NONE), "body"); + EXPECT_EQ(node.get_text(lazyllm::MetadataMode::EMBED), "alpha: A\n\nbody"); +} + +TEST(doc_node, relationships) { + lazyllm::DocNode root("root", "root_group", "root_uid"); + lazyllm::DocNode child("child", "child_group", "child_uid", &root); + EXPECT_EQ(child.get_root_node(), &root); + EXPECT_EQ(child.get_parent_uid(), "root_uid"); + + root.set_children_group("split", {std::make_shared(std::move(child))}); + EXPECT_TRUE(root.is_children_group_exists("split")); + const auto child_ids = root.get_children_uid(); + ASSERT_TRUE(child_ids.find("split") != child_ids.end()); + ASSERT_EQ(child_ids.at("split").size(), 1u); + EXPECT_EQ(child_ids.at("split")[0], "child_uid"); +} + +TEST(doc_node, children_caching) { + lazyllm::DocNode parent("p", "", "parent"); + lazyllm::DocNode child("c", "", "child", &parent); + + auto adaptor = std::make_shared(); + adaptor->to_return["cached"] = {std::make_shared(std::move(child))}; + parent.set_store(adaptor); + + const auto first = parent.get_children(); + const auto second = parent.get_children(); + + EXPECT_EQ(adaptor->call_count, 1); + ASSERT_TRUE(first.find("cached") != first.end()); + ASSERT_EQ(first.at("cached").size(), 1u); + EXPECT_EQ(first.at("cached")[0], second.at("cached")[0]); +} + +TEST(doc_node, doc_path_reads_from_root_global_metadata) { + auto global_meta = std::make_shared(); + (*global_meta)[std::string(lazyllm::RAGMetadataKeys::DOC_PATH)] = std::string("/tmp/a.txt"); + lazyllm::DocNode root("root", "", "root", nullptr, {}, global_meta); + lazyllm::DocNode child("child", "", "child", &root, {}, global_meta); + + EXPECT_EQ(child.get_doc_path(), "/tmp/a.txt"); + + child.set_doc_path("/tmp/b.txt"); + EXPECT_EQ(root.get_doc_path(), "/tmp/b.txt"); +} + +TEST(doc_node, embedding_keys_undone_throws_on_empty_input) { + lazyllm::DocNode node; + EXPECT_THROW(node.embedding_keys_undone({}), std::runtime_error); + + node.set_embedding_vec("done", {1.0, 2.0}); + const auto missing = node.embedding_keys_undone({"done", "todo"}); + EXPECT_EQ(missing, (std::set{"todo"})); +} + +TEST(doc_node, py_do_embedding_writes_embedding_vector) { + lazyllm::DocNode 
node("text"); + + node.py_do_embedding({ + {"len_embedding", [](const std::string& input) { + return std::vector{static_cast(input.size())}; + }} + }); + + ASSERT_TRUE(node._embedding_vecs.find("len_embedding") != node._embedding_vecs.end()); + ASSERT_EQ(node._embedding_vecs["len_embedding"].size(), 1u); + EXPECT_EQ(node._embedding_vecs["len_embedding"][0], 4.0); // "text" +} + +TEST(doc_node, equality_uses_uid) { + lazyllm::DocNode lhs("left", "", "same_uid"); + lazyllm::DocNode rhs("right", "", "same_uid"); + EXPECT_TRUE(lhs == rhs); + + lazyllm::DocNode other("other", "", "other_uid"); + EXPECT_TRUE(lhs != other); +} diff --git a/csrc/tests/test_map_params.cpp b/csrc/tests/test_map_params.cpp new file mode 100644 index 000000000..10ecc5bae --- /dev/null +++ b/csrc/tests/test_map_params.cpp @@ -0,0 +1,60 @@ +#include + +#include +#include +#include + +#include "map_params.hpp" + +TEST(map_params, get_param_value) { + lazyllm::MapParams params; + params.set_default("chunk_size", 1024u); + + auto value = params.get_param_value("chunk_size", 4096u, 2048u); + EXPECT_EQ(value, 4096u); + + value = params.get_param_value("chunk_size", std::nullopt, 4u); + EXPECT_EQ(value, 1024u); + + EXPECT_EQ(params.get_param_value("missing", std::nullopt, 4u), 4u); +} + +TEST(map_params, bulk_set_updates_defaults) { + lazyllm::MapParams params; + lazyllm::MapParams::MapType updates{ + {"encoding", std::string("gpt2")}, + {"overlap", 128u}, + }; + params.set_default(updates); + EXPECT_EQ(std::any_cast(params.get_default().at("encoding")), "gpt2"); + EXPECT_EQ(std::any_cast(params.get_default().at("overlap")), 128u); +} + +TEST(map_params, typed_get_default_returns_value) { + lazyllm::MapParams params; + params.set_default("encoding", std::string("gpt2")); + params.set_default("overlap", 128u); + + const auto encoding = params.get_default("encoding"); + const auto overlap = params.get_default("overlap"); + + ASSERT_TRUE(encoding.has_value()); + ASSERT_TRUE(overlap.has_value()); + EXPECT_EQ(*encoding, "gpt2"); + EXPECT_EQ(*overlap, 128u); +} + +TEST(map_params, typed_get_default_returns_nullopt_when_missing) { + lazyllm::MapParams params; + const auto missing = params.get_default("chunk_size"); + EXPECT_FALSE(missing.has_value()); +} + +TEST(map_params, reset_default_clears_state) { + lazyllm::MapParams params; + params.set_default("k", 1); + ASSERT_FALSE(params.get_default().empty()); + + params.reset_default(); + EXPECT_TRUE(params.get_default().empty()); +} diff --git a/csrc/tests/test_node_transform.cpp b/csrc/tests/test_node_transform.cpp new file mode 100644 index 000000000..41833b0d7 --- /dev/null +++ b/csrc/tests/test_node_transform.cpp @@ -0,0 +1,83 @@ +#include + +#include "doc_node.hpp" +#include "node_transform.hpp" + +using namespace lazyllm; + +namespace { + +class PairTransform final : public NodeTransform { +public: + explicit PairTransform(int worker_num = 0) : NodeTransform(worker_num) {} + + mutable int transform_calls = 0; + + std::vector transform(PDocNode node) const override { + ++transform_calls; + const auto& shared_meta = node->get_root_node()->_p_global_metadata; + return { + std::make_shared(std::string(node->get_text_view()) + "_A", shared_meta), + std::make_shared(std::string(node->get_text_view()) + "_B", shared_meta) + }; + } +}; + +class SingleTransform final : public NodeTransform { +public: + explicit SingleTransform(int worker_num) : NodeTransform(worker_num) {} + + std::vector transform(PDocNode node) const override { + const auto shared_meta = 
+
+TEST(map_params, get_param_value) {
+    lazyllm::MapParams params;
+    params.set_default("chunk_size", 1024u);
+
+    auto value = params.get_param_value("chunk_size", 4096u, 2048u);
+    EXPECT_EQ(value, 4096u);
+
+    value = params.get_param_value("chunk_size", std::nullopt, 4u);
+    EXPECT_EQ(value, 1024u);
+
+    EXPECT_EQ(params.get_param_value("missing", std::nullopt, 4u), 4u);
+}
+
+TEST(map_params, bulk_set_updates_defaults) {
+    lazyllm::MapParams params;
+    lazyllm::MapParams::MapType updates{
+        {"encoding", std::string("gpt2")},
+        {"overlap", 128u},
+    };
+    params.set_default(updates);
+    EXPECT_EQ(std::any_cast<std::string>(params.get_default().at("encoding")), "gpt2");
+    EXPECT_EQ(std::any_cast<unsigned>(params.get_default().at("overlap")), 128u);
+}
+
+TEST(map_params, typed_get_default_returns_value) {
+    lazyllm::MapParams params;
+    params.set_default("encoding", std::string("gpt2"));
+    params.set_default("overlap", 128u);
+
+    const auto encoding = params.get_default<std::string>("encoding");
+    const auto overlap = params.get_default<unsigned>("overlap");
+
+    ASSERT_TRUE(encoding.has_value());
+    ASSERT_TRUE(overlap.has_value());
+    EXPECT_EQ(*encoding, "gpt2");
+    EXPECT_EQ(*overlap, 128u);
+}
+
+TEST(map_params, typed_get_default_returns_nullopt_when_missing) {
+    lazyllm::MapParams params;
+    const auto missing = params.get_default<unsigned>("chunk_size");
+    EXPECT_FALSE(missing.has_value());
+}
+
+TEST(map_params, reset_default_clears_state) {
+    lazyllm::MapParams params;
+    params.set_default("k", 1);
+    ASSERT_FALSE(params.get_default().empty());
+
+    params.reset_default();
+    EXPECT_TRUE(params.get_default().empty());
+}
diff --git a/csrc/tests/test_node_transform.cpp b/csrc/tests/test_node_transform.cpp
new file mode 100644
index 000000000..41833b0d7
--- /dev/null
+++ b/csrc/tests/test_node_transform.cpp
@@ -0,0 +1,83 @@
+#include <gtest/gtest.h>
+
+#include "doc_node.hpp"
+#include "node_transform.hpp"
+
+using namespace lazyllm;
+
+namespace {
+
+class PairTransform final : public NodeTransform {
+public:
+    explicit PairTransform(int worker_num = 0) : NodeTransform(worker_num) {}
+
+    mutable int transform_calls = 0;
+
+    std::vector<PDocNode> transform(PDocNode node) const override {
+        ++transform_calls;
+        const auto& shared_meta = node->get_root_node()->_p_global_metadata;
+        return {
+            std::make_shared<DocNode>(std::string(node->get_text_view()) + "_A", shared_meta),
+            std::make_shared<DocNode>(std::string(node->get_text_view()) + "_B", shared_meta)
+        };
+    }
+};
+
+class SingleTransform final : public NodeTransform {
+public:
+    explicit SingleTransform(int worker_num) : NodeTransform(worker_num) {}
+
+    std::vector<PDocNode> transform(PDocNode node) const override {
+        const auto shared_meta = node->get_root_node()->_p_global_metadata;
+        return {
+            std::make_shared<DocNode>(std::string(node->get_text_view()) + "_child", shared_meta)
+        };
+    }
+};
+
+std::vector<PDocNode> make_roots(size_t count) {
+    std::vector<PDocNode> roots;
+    roots.reserve(count);
+    for (size_t i = 0; i < count; ++i)
+        roots.push_back(std::make_shared<DocNode>("n" + std::to_string(i), "", "uid_" + std::to_string(i)));
+    return roots;
+}
+
+} // namespace
+
+TEST(node_transform, call_operator_returns_transform_result) {
+    auto roots = make_roots(1);
+    PairTransform transform;
+    const auto direct = transform(roots[0]);
+    EXPECT_EQ(direct.size(), 2u);
+}
+
+TEST(node_transform, batch_forward) {
+    auto roots = make_roots(2);
+    PairTransform transform;
+    const auto children = transform.batch_forward(roots, "split");
+
+    EXPECT_EQ(children.size(), 4u);
+    for (auto child : children) {
+        EXPECT_NE(child->get_parent_node(), nullptr);
+        EXPECT_EQ(child->_group_name, "split");
+    }
+
+    for (auto root : roots) {
+        EXPECT_TRUE(root->is_children_group_exists("split"));
+        EXPECT_EQ(root->get_children().at("split").size(), 2u);
+    }
+}
+
+TEST(node_transform, batch_forward_parallel_mode_respects_worker_num) {
+    auto roots = make_roots(8);
+    SingleTransform transform(3);
+    EXPECT_EQ(transform.worker_num(), 3);
+
+    const auto children = transform.batch_forward(roots, "parallel");
+    EXPECT_EQ(children.size(), roots.size());
+    for (auto root : roots) {
+        EXPECT_TRUE(root->is_children_group_exists("parallel"));
+        EXPECT_EQ(root->get_children().at("parallel").size(), 1u);
+    }
+}
diff --git a/csrc/tests/test_sentence_splitter.cpp b/csrc/tests/test_sentence_splitter.cpp
new file mode 100644
index 000000000..150002006
--- /dev/null
+++ b/csrc/tests/test_sentence_splitter.cpp
@@ -0,0 +1,56 @@
+#include <gtest/gtest.h>
+
+#include <stdexcept>
+#include <vector>
+
+#include "sentence_splitter.hpp"
+#include "utils.hpp"
+
+namespace {
+
+class TestSentenceSplitter final : public lazyllm::SentenceSplitter {
+public:
+    TestSentenceSplitter(unsigned chunk_size, unsigned overlap)
+        : lazyllm::SentenceSplitter(chunk_size, overlap, 0) {}
+
+    using lazyllm::SentenceSplitter::merge_chunks;
+};
+
+} // namespace
+
+TEST(sentence_splitter, merge_chunks_applies_overlap) {
+    TestSentenceSplitter splitter(5, 2);
+
+    const std::vector<lazyllm::ChunkView> splits{
+        {"ab", false, 2},
+        {"cd", false, 2},
+        {"ef", false, 2},
+    };
+
+    const auto merged = splitter.merge_chunks(splits, 5);
+    EXPECT_EQ(merged, (std::vector<std::string>{"abcd", "cdef"}));
+}
+
+TEST(sentence_splitter, merge_chunks_throws_on_oversized_single_split) {
+    TestSentenceSplitter splitter(3, 1);
+    const std::vector<lazyllm::ChunkView> splits{
+        {"abcd", false, 4},
+    };
+
+    EXPECT_THROW((void)splitter.merge_chunks(splits, 3), std::runtime_error);
+}
+
+TEST(sentence_splitter, merge_chunks_shrinks_overlap_to_fit_next_chunk) {
+    TestSentenceSplitter splitter(5, 4);
+
+    const std::vector<lazyllm::ChunkView> splits{
+        {"aa", false, 2},
+        {"b", false, 1},
+        {"cccc", false, 4},
+        {"dd", false, 2},
+        {"ee", false, 2},
+    };
+
+    const auto merged = splitter.merge_chunks(splits, 5);
+    EXPECT_EQ(merged, (std::vector<std::string>{"aab", "bcccc", "ddee"}));
+}
diff --git a/csrc/tests/test_smoke.cpp b/csrc/tests/test_smoke.cpp
index 4c291b34b..998974c44 100644
--- a/csrc/tests/test_smoke.cpp
+++ b/csrc/tests/test_smoke.cpp
@@ -2,6 +2,6 @@
 
 #include "lazyllm.hpp"
 
-TEST(LazyLLM, Smoke) {
+TEST(lazyllm, smoke) {
   EXPECT_GT(PYBIND11_VERSION_MAJOR, 0);
 }
diff --git a/csrc/tests/test_text_splitter_base.cpp b/csrc/tests/test_text_splitter_base.cpp
new file mode 100644
index 000000000..0b0fbdeb7
--- /dev/null
+++ b/csrc/tests/test_text_splitter_base.cpp
@@ -0,0 +1,119 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include "doc_node.hpp"
+#include "text_splitter_base.hpp"
+#include "utils.hpp"
+
+namespace {
+
+class ByteTokenizer final : public Tokenizer {
+public:
+    std::vector<int> encode(const std::string_view& view) const override {
+        std::vector<int> out;
+        out.reserve(view.size());
+        for (unsigned char ch : view) {
+            out.push_back(static_cast<int>(ch));
+        }
+        return out;
+    }
+
+    std::string decode(const std::vector<int>& token_ids) const override {
+        std::string out;
+        out.reserve(token_ids.size());
+        for (int token_id : token_ids) {
+            out.push_back(static_cast<char>(token_id));
+        }
+        return out;
+    }
+};
+
+class TestTextSplitter final : public lazyllm::TextSplitterBase {
+public:
+    TestTextSplitter(unsigned chunk_size, unsigned overlap = 0)
+        : lazyllm::TextSplitterBase(chunk_size, overlap, 0) {}
+
+    using lazyllm::TextSplitterBase::merge_chunks;
+    using lazyllm::TextSplitterBase::split_recursive;
+};
+
+} // namespace
+
+TEST(text_splitter_base, exception_management) {
+    EXPECT_THROW((void)lazyllm::TextSplitterBase(10, 11), std::runtime_error);
+    EXPECT_THROW((void)lazyllm::TextSplitterBase(0, 0), std::runtime_error);
+}
+
+TEST(text_splitter_base, split_text_keep_separator_returns_segments) {
+    const auto parts = lazyllm::TextSplitterBase::split_text_while_keeping_separator("a--b--", "--");
+    ASSERT_EQ(parts.size(), 2u);
+    EXPECT_EQ(parts[0], "a--");
+    EXPECT_EQ(parts[1], "b--");
+}
+
+TEST(text_splitter_base, split_text_keep_separator_skips_leading_separator) {
+    const auto parts = lazyllm::TextSplitterBase::split_text_while_keeping_separator("--x", "--");
+    ASSERT_EQ(parts.size(), 1u);
+    EXPECT_EQ(parts[0], "x");
+}
+
+TEST(text_splitter_base, split_text_throws_when_metadata_exceeds_chunk_size) {
+    lazyllm::TextSplitterBase splitter(60, 0);
+    splitter.set_tokenizer(std::make_shared<ByteTokenizer>());
+    EXPECT_THROW((void)splitter.split_text("abc", 60), std::runtime_error);
+}
+
+TEST(text_splitter_base, split_text_throws_when_metadata_budget_too_small) {
+    lazyllm::TextSplitterBase splitter(60, 0);
+    splitter.set_tokenizer(std::make_shared<ByteTokenizer>());
+    EXPECT_THROW((void)splitter.split_text("abc", 11), std::runtime_error);
+}
+
+TEST(text_splitter_base, split_recursive_falls_back_to_char_level) {
+    TestTextSplitter splitter(100, 0);
+    splitter.set_tokenizer(std::make_shared<ByteTokenizer>());
+
+    const auto chunks = splitter.split_recursive("abc", 2);
+    ASSERT_EQ(chunks.size(), 3u);
+    EXPECT_EQ(chunks[0].view, "a");
+    EXPECT_EQ(chunks[1].view, "b");
+    EXPECT_EQ(chunks[2].view, "c");
+    EXPECT_FALSE(chunks[0].is_sentence);
+}
+
+TEST(text_splitter_base, merge_chunks_uses_overlap) {
+    TestTextSplitter splitter(100, 1);
+    splitter.set_tokenizer(std::make_shared<ByteTokenizer>());
+
+    const std::vector<lazyllm::ChunkView> splits{
+        {"ab", true, 2},
+        {"cd", true, 2},
+        {"ef", true, 2},
+    };
+
+    const auto merged = splitter.merge_chunks(splits, 4);
+    EXPECT_EQ(merged, (std::vector<std::string>{"ab", "bcd", "def"}));
+}
+
+TEST(text_splitter_base, transform_returns_chunk_nodes) {
+    lazyllm::TextSplitterBase splitter(100, 0);
+    splitter.set_tokenizer(std::make_shared<ByteTokenizer>());
+
+    lazyllm::PDocNode node = std::make_shared<lazyllm::DocNode>("hello");
+
+    auto chunks = splitter.transform(node);
+    ASSERT_EQ(chunks.size(), 1u);
+    EXPECT_EQ(chunks[0]->get_text(), "hello");
+
+    chunks = splitter.transform(nullptr);
+    EXPECT_TRUE(chunks.empty());
+}
+
+TEST(text_splitter_base, from_tiktoken_encoder_throws_on_invalid_name) {
+    lazyllm::TextSplitterBase splitter(100, 0);
+    EXPECT_THROW((void)splitter.from_tiktoken_encoder("definitely_unknown"), std::runtime_error);
+}
diff --git a/csrc/tests/test_thread_pool.cpp b/csrc/tests/test_thread_pool.cpp
new file mode 100644
index 000000000..b834fe63c
--- /dev/null
+++ b/csrc/tests/test_thread_pool.cpp
@@ -0,0 +1,28 @@
+#include <gtest/gtest.h>
+
+#include <future>
+#include <stdexcept>
+
+#include "thread_pool.hpp"
+
+TEST(thread_pool, executes_tasks) {
+    ThreadPool pool(3);
+
+    auto f1 = pool.enqueue([] { return 1 + 2; });
+    EXPECT_EQ(f1.get(), 3);
+}
+
+TEST(thread_pool, returns_values_from_futures) {
+    ThreadPool pool(3);
+    auto f2 = pool.enqueue([](int v) { return v * 2; }, 5);
+    EXPECT_EQ(f2.get(), 10);
+}
+
+TEST(thread_pool, propagates_task_exception_through_future) {
+    ThreadPool pool(1);
+    auto failing = pool.enqueue([]() -> int {
+        throw std::runtime_error("boom");
+    });
+
+    EXPECT_THROW((void)failing.get(), std::runtime_error);
+}
diff --git a/csrc/tests/test_tokenizer.cpp b/csrc/tests/test_tokenizer.cpp
new file mode 100644
index 000000000..346aa4e39
--- /dev/null
+++ b/csrc/tests/test_tokenizer.cpp
@@ -0,0 +1,56 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tokenizer.hpp"
+
+namespace {
+
+class IdentityTokenizer final : public Tokenizer {
+public:
+    std::vector<int> encode(const std::string_view& view) const override {
+        std::vector<int> out;
+        out.reserve(view.size());
+        for (unsigned char ch : view) out.push_back(static_cast<int>(ch));
+        return out;
+    }
+
+    std::string decode(const std::vector<int>& token_ids) const override {
+        std::string out;
+        out.reserve(token_ids.size());
+        for (int id : token_ids) out.push_back(static_cast<char>(id));
+        return out;
+    }
+};
+
+} // namespace
+
+TEST(tokenizer, abstract_interface_via_derived_class) {
+    std::unique_ptr<Tokenizer> tokenizer = std::make_unique<IdentityTokenizer>();
+    const auto ids = tokenizer->encode("abc");
+
+    EXPECT_EQ(ids, (std::vector<int>{97, 98, 99}));
+    EXPECT_EQ(tokenizer->decode(ids), "abc");
+}
+
+TEST(tiktoken_tokenizer, round_trip_encoding) {
+    TiktokenTokenizer tokenizer("gpt2");
+    const std::string text = "hello tokenizer";
+
+    const auto token_ids = tokenizer.encode(text);
+    EXPECT_FALSE(token_ids.empty());
+    EXPECT_EQ(tokenizer.decode(token_ids), text);
+}
+
+TEST(tiktoken_tokenizer, alias_names_map_to_same_encoding) {
+    TiktokenTokenizer gpt2("gpt2");
+    TiktokenTokenizer r50k("r50k_base");
+
+    EXPECT_EQ(gpt2.encode("same input"), r50k.encode("same input"));
+}
+
+TEST(tiktoken_tokenizer, unknown_encoding_throws) {
+    EXPECT_THROW((void)TiktokenTokenizer("unknown_model"), std::runtime_error);
+}
diff --git a/csrc/tests/test_unicode_processor.cpp b/csrc/tests/test_unicode_processor.cpp
new file mode 100644
index 000000000..ee86d29e9
--- /dev/null
+++ b/csrc/tests/test_unicode_processor.cpp
@@ -0,0 +1,30 @@
+#include <gtest/gtest.h>
+
+#include <string>
+#include <vector>
+
+#include "unicode_processor.hpp"
+
+TEST(unicode_processor, split_to_chars_supports_multibyte) {
+    const std::string text = "a你🙂";
+    const lazyllm::UnicodeProcessor processor(text);
+
+    const auto chars = processor.split_to_chars();
+    EXPECT_EQ(chars, (std::vector<std::string_view>{"a", "你", "🙂"}));
+}
+
+TEST(unicode_processor, split_by_punctuation_for_ascii) {
+    const std::string text = "Hello,world!";
+    const lazyllm::UnicodeProcessor processor(text);
+
+    const auto chunks = processor.split_by_punctuation();
+    EXPECT_EQ(chunks, (std::vector<std::string_view>{"Hello,", "world!"}));
+}
+
+TEST(unicode_processor, split_by_punctuation_for_cjk) {
+    const std::string text = "你好。世界!";
+    const lazyllm::UnicodeProcessor processor(text);
+
+    const auto chunks = processor.split_by_punctuation();
+    EXPECT_EQ(chunks, (std::vector<std::string_view>{"你好。", "世界!"}));
+}
diff --git a/csrc/tests/test_utils.cpp b/csrc/tests/test_utils.cpp
new file mode 100644
index 000000000..9b0850865
--- /dev/null
+++ b/csrc/tests/test_utils.cpp
@@ -0,0 +1,94 @@
+#include <gtest/gtest.h>
+
+#include <regex>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "utils.hpp"
+
+TEST(utils, join_lines_returns_empty_for_empty_input) {
+    EXPECT_EQ(lazyllm::JoinLines({}), "");
+}
+
+TEST(utils, join_lines_uses_newline_separator_by_default) {
+    EXPECT_EQ(lazyllm::JoinLines({"a", "b", "c"}), "a\nb\nc");
+}
+
+TEST(utils, join_lines_supports_custom_delimiter) {
+    EXPECT_EQ(lazyllm::JoinLines({"a", "b", "c"}, ','), "a,b,c");
+}
+
+TEST(utils, concat_vector_appends_right_sequence) {
+    const auto merged = lazyllm::ConcatVector(std::vector<int>{1, 2}, std::vector<int>{3, 4});
+    EXPECT_EQ(merged, (std::vector<int>{1, 2, 3, 4}));
+}
+
+TEST(utils, set_union_returns_all_unique_values) {
+    const std::set<int> left{1, 2, 3};
+    const std::set<int> right{3, 4};
+    EXPECT_EQ(lazyllm::SetUnion(left, right), (std::set<int>{1, 2, 3, 4}));
+}
+
+TEST(utils, set_diff_returns_only_left_unique_values) {
+    const std::set<int> left{1, 2, 3};
+    const std::set<int> right{3, 4};
+    EXPECT_EQ(lazyllm::SetDiff(left, right), (std::set<int>{1, 2}));
+}
+
+TEST(utils, to_hex_returns_lowercase_hex_text) {
+    EXPECT_EQ(lazyllm::to_hex(255u), "ff");
+}
+
+TEST(utils, generate_uuid_matches_expected_pattern) {
+    const std::string uuid = lazyllm::GenerateUUID();
+    const std::regex pattern("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$");
+    EXPECT_TRUE(std::regex_match(uuid, pattern));
+}
+
+TEST(utils, is_adjacent_returns_true_for_contiguous_views) {
+    const std::string text = "abcdef";
+    const std::string_view left = std::string_view(text.data(), 3);
+    const std::string_view right = std::string_view(text.data() + 3, 3);
+    EXPECT_TRUE(lazyllm::is_adjacent(left, right));
+}
+
+TEST(utils, is_adjacent_returns_false_for_non_contiguous_views) {
+    const std::string text = "abcdef";
+    const std::string_view left = std::string_view(text.data(), 3);
+    const std::string_view right = std::string_view(text.data() + 4, 2);
+    EXPECT_FALSE(lazyllm::is_adjacent(left, right));
+}
+
+TEST(utils, chunk_operator_plus_equals_accumulates_fields) {
+    lazyllm::Chunk l{"ab", true, 2};
+    lazyllm::Chunk r{"cd", false, 3};
+
+    l += r;
+    EXPECT_EQ(l.text, "abcd");
+    EXPECT_FALSE(l.is_sentence);
+    EXPECT_EQ(l.token_size, 5);
+}
+
+TEST(utils, rag_metadata_keys_constants_are_exposed) {
+    EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_PATH, "lazyllm_doc_path");
+    EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_ID, "docid");
+}
+
+TEST(utils, any_to_string_formats_scalar_metadata_values) {
+    EXPECT_EQ(lazyllm::any_to_string(lazyllm::MetadataVType(std::string("alpha"))), "alpha");
+    EXPECT_EQ(lazyllm::any_to_string(lazyllm::MetadataVType(7)), "7");
+    EXPECT_EQ(lazyllm::any_to_string(lazyllm::MetadataVType(3.5)), "3.5");
+}
+
+TEST(utils, any_to_string_formats_vector_metadata_values_with_brackets) {
+    EXPECT_EQ(
+        lazyllm::any_to_string(lazyllm::MetadataVType(std::vector<std::string>{"a", "b"})),
+        "[a,b]");
+    EXPECT_EQ(
+        lazyllm::any_to_string(lazyllm::MetadataVType(std::vector<int>{1, 2, 3})),
+        "[1,2,3]");
+    EXPECT_EQ(
+        lazyllm::any_to_string(lazyllm::MetadataVType(std::vector<double>{1.5, 2.0})),
+        "[1.5,2]");
+}
diff --git a/lazyllm/cpp.py b/lazyllm/cpp.py
index 66e213d70..74ae455a2 100644
--- a/lazyllm/cpp.py
+++ b/lazyllm/cpp.py
@@ -1,4 +1,37 @@
-try:
-    from .lazyllm_cpp import *  # noqa F403
-except ImportError:
-    pass
+import importlib
+import os
+from typing import TypeVar, cast
+
+_LAZYLLM_CPP_MODULE = None
+_LAZYLLM_CPP_ENABLED = None
+_C = TypeVar('_C', bound=type)
+
+
+def _is_enabled() -> bool:
+    global _LAZYLLM_CPP_ENABLED
+    if _LAZYLLM_CPP_ENABLED is None:
+        value = os.getenv('LAZYLLM_ENABLE_CPP_OVERRIDE')
+        _LAZYLLM_CPP_ENABLED = value is not None and (value == '1' or value.lower() == 'true')
+    return _LAZYLLM_CPP_ENABLED
+
+
+def _load_cpp_module():
+    global _LAZYLLM_CPP_MODULE
+    if _LAZYLLM_CPP_MODULE is None:
+        _LAZYLLM_CPP_MODULE = importlib.import_module('lazyllm.lazyllm_cpp')
+    return _LAZYLLM_CPP_MODULE
+
+
+def cpp_class(py_class: _C) -> _C:
+    if not isinstance(py_class, type):
+        raise TypeError(f'@cpp_class can only decorate classes, got: {type(py_class).__name__}')
+
+    if not _is_enabled():
+        return py_class
+
+    cpp_module = _load_cpp_module()
+    export_name = py_class.__name__
+
+    cpp_export = getattr(cpp_module, export_name)
+
+    return cast(_C, cpp_export)
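+
+
+# Usage sketch (DocNode is one of the classes decorated in this PR; the C++
+# module must expose an export with the same name):
+#
+#     @cpp_class
+#     class DocNode: ...
+#
+# With LAZYLLM_ENABLE_CPP_OVERRIDE unset, or set to anything other than '1' or
+# 'true', the decorator is a no-op and the pure-Python class is kept; otherwise
+# the class is replaced by the same-named export from lazyllm.lazyllm_cpp.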
diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py
index 79fda5cae..b73c3b4d4 100644
--- a/lazyllm/tools/rag/doc_node.py
+++ b/lazyllm/tools/rag/doc_node.py
@@ -3,6 +3,7 @@
 from collections import defaultdict
 from lazyllm.thirdparty import PIL
 from lazyllm import JsonFormatter, config, reset_on_pickle, Mode, LOG
+from lazyllm.cpp import cpp_class
 from lazyllm.components.utils.file_operate import _image_to_base64
 from .global_metadata import RAG_DOC_ID, RAG_DOC_PATH, RAG_KB_ID
 import uuid
@@ -22,6 +23,7 @@ class MetadataMode(str, Enum):
     NONE = auto()
 
 
+@cpp_class
 @reset_on_pickle(('_lock', threading.Lock))
 class DocNode:
     def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[Any]]] = None,
@@ -51,7 +53,7 @@ def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[
         self._store = store
         self._node_groups: Dict[str, Dict] = node_groups or {}
         self._lock = threading.Lock()
-        self._embedding_state = set()
+        self.embedding_state = set()
         self.relevance_score = None
         self.similarity_score = None
         self._content_hash: Optional[str] = None
@@ -256,7 +258,7 @@ def check_embedding_state(self, embed_key: str) -> None:
         while True:
             with self._lock:
                 if not self.has_missing_embedding(embed_key):
-                    self._embedding_state.discard(embed_key)
+                    self.embedding_state.discard(embed_key)
                     break
             time.sleep(1)
diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py
index 36f59730f..42228592d 100644
--- a/lazyllm/tools/rag/transform/base.py
+++ b/lazyllm/tools/rag/transform/base.py
@@ -16,11 +16,13 @@
 import threading
 from lazyllm.thirdparty import tiktoken
 from lazyllm import config, ModuleBase
+from lazyllm.cpp import cpp_class
 from pathlib import Path
 import inspect
 from lazyllm.thirdparty import nltk
 from lazyllm.thirdparty import transformers
 
 
+@cpp_class
 class MetadataMode(str, Enum):
     ALL = 'ALL'
     EMBED = 'EMBED'
@@ -61,6 +63,7 @@ def split_text_keep_separator(text: str, separator: str) -> List[str]:
     return result
 
 
+@cpp_class
 class NodeTransform(ModuleBase):
     __support_rich__ = False
@@ -206,6 +209,7 @@ def _forward_single(self, node: Union[DocNode, RichDocNode], **kwargs: Any) -> L
 _UNSET = object()
 
 
+@cpp_class
 class _TextSplitterBase(NodeTransform):
     _default_params = {}
     _default_params_lock = threading.RLock()
@@ -639,3 +643,105 @@ def __len__(self) -> int:
 
     def __bool__(self) -> bool:
         return bool(self._rules)
+
+
+@dataclass(frozen=True)
+class Rule:
+    name: str
+    match: Callable[[Any], Any]
+    apply: Callable[..., Any]
+    priority: int = 0
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def __call__(self, data: Any) -> Optional[Any]:
+        match_result = self.match(data)
+        if not match_result:
+            return None
+        try:
+            sig = inspect.signature(self.apply)
+            params = [
+                p for p in sig.parameters.values()
+                if p.kind in (
+                    inspect.Parameter.POSITIONAL_ONLY,
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD
+                )
+            ]
+            if len(params) >= 3:
+                return self.apply(data, match_result, self)
+        except (ValueError, TypeError):
+            pass
+        return self.apply(data, self)
+
+    @staticmethod
+    def build(name: str, rule: Union[str, Callable[[Any], bool]],
+              apply: Callable[[Any, 'Rule'], Any], priority: int = 0) -> 'Rule':
+        if isinstance(rule, str):
+            compiled = re.compile(rule)
+            return Rule(
+                name=name,
+                match=lambda text: compiled.search(text),
+                apply=lambda text, match_result, r: apply(match_result, text),
+                priority=priority,
+            )
+        if callable(rule):
+            return Rule(
+                name=name,
+                match=lambda data: True if rule(data) else None,
+                apply=lambda data, _match_result, r: apply(data, r),
+                priority=priority,
+            )
+        raise TypeError('rule must be a pattern string or a predicate callable')
+
+
+@dataclass
+class _Context:
+    total: int
+    current_idx: int = 0
+    prev_node: Optional[Any] = None
+    prev_result: Optional[Any] = None
+    user_data: Dict[str, Any] = field(default_factory=dict)
+
+
+class RuleSet:
+    def __init__(self, rules: Optional[List[Rule]] = None):
+        self._rules: List[Rule] = []
+        if rules:
+            self.extend(rules)
+
+    def add(self, *rules: Rule) -> 'RuleSet':
+        self._rules.extend(rules)
+        self._sort()
+        return self
+
+    def extend(self, rules: List[Rule]) -> 'RuleSet':
+        self._rules.extend(rules)
+        self._sort()
+        return self
+
+    def _sort(self):
+        self._rules.sort(key=lambda r: r.priority, reverse=True)
+
+    def first(self, data: Any) -> Optional[tuple[Rule, Any]]:
+        for rule in self._rules:
+            result = rule(data)
+            if result is not None:
+                return (rule, result)
+        return None
+
+    def all(self, data: Any) -> List[tuple[Rule, Any]]:
+        results = []
+        for rule in self._rules:
+            result = rule(data)
+            if result is not None:
+                results.append((rule, result))
+        return results
+
+    def filter(self, predicate: Callable[[Rule], bool]) -> 'RuleSet':
+        return RuleSet([r for r in self._rules if predicate(r)])
+
+    def __iter__(self) -> Iterator[Rule]:
+        return iter(self._rules)
+
+    def __len__(self) -> int:
+        return len(self._rules)
+
+    def __bool__(self) -> bool:
+        return bool(self._rules)
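+
+
+# Usage sketch (hypothetical rules, following Rule.build's two dispatch forms:
+# a regex pattern whose apply receives (match, text), and a predicate whose
+# apply receives (data, rule)):
+#
+#     rules = RuleSet([
+#         Rule.build('heading', r'^#+\s', lambda match, text: ('heading', text)),
+#         Rule.build('nonempty', lambda t: bool(t.strip()), lambda t, r: ('text', t)),
+#     ])
+#     rules.first('# Title')  # -> (heading rule, ('heading', '# Title'))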
and k in n._embedding_state: + if k in n.embedding_state: with n._lock: - n._embedding_state.remove(k) + n.embedding_state.remove(k) raise e with ThreadPoolExecutor(max_workers=min(max_workers, len(tasks_by_key))) as ex: diff --git a/tests/cpp_ext_tests/test_doc_node.py b/tests/cpp_ext_tests/test_doc_node.py deleted file mode 100644 index 40991bb87..000000000 --- a/tests/cpp_ext_tests/test_doc_node.py +++ /dev/null @@ -1,13 +0,0 @@ -class TestDocNode: - def setup_method(self): - from lazyllm import lazyllm_cpp - self.lazyllm_cpp = lazyllm_cpp - - def test_doc_node_set_get(self): - node = self.lazyllm_cpp.DocNode() - assert node.get_text() == '' - node.set_text('hello') - assert node.get_text() == 'hello' - - node2 = self.lazyllm_cpp.DocNode('world') - assert node2.get_text() == 'world' diff --git a/tests/test_cpp_class_decorator.py b/tests/test_cpp_class_decorator.py new file mode 100644 index 000000000..dfc913e35 --- /dev/null +++ b/tests/test_cpp_class_decorator.py @@ -0,0 +1,71 @@ +import importlib +import pytest +from types import SimpleNamespace + + +def _reload_cpp_module(): + import lazyllm.cpp as cpp + return importlib.reload(cpp) + + +def test_cpp_class_keeps_python_class_when_disabled(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '0') + cpp = _reload_cpp_module() + + class PyOnly: + pass + + replaced = cpp.cpp_class(PyOnly) + assert replaced is PyOnly + + +def test_cpp_class_replaces_with_cpp_export_when_enabled(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + + class CppDummy: + pass + + monkeypatch.setattr(cpp, '_load_cpp_module', lambda: SimpleNamespace(Dummy=CppDummy)) + + class Dummy: + pass + + replaced = cpp.cpp_class(Dummy) + assert replaced is CppDummy + + +def test_cpp_class_rejects_non_class_object(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + + with pytest.raises(TypeError, match='can only decorate classes'): + cpp.cpp_class('NotAClass') + + +def test_cpp_class_raises_when_cpp_export_missing(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + monkeypatch.setattr(cpp, '_load_cpp_module', lambda: SimpleNamespace()) + + class Missing: + pass + + with pytest.raises(AttributeError, match="has no attribute 'Missing'"): + cpp.cpp_class(Missing) + + +def test_cpp_class_propagates_import_error_when_enabled(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + + def _boom(): + raise ImportError('boom') + + monkeypatch.setattr(cpp, '_load_cpp_module', _boom) + + class AnyClass: + pass + + with pytest.raises(ImportError, match='boom'): + cpp.cpp_class(AnyClass)