From e011555e2bfaf086d49df936fa4d0916afc93833 Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 20 Jan 2026 19:02:30 +0800 Subject: [PATCH 01/40] workflow fix --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f0e41b730..16355f400 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -672,6 +672,7 @@ jobs: cpp_ext_test: name: C++ Extension Test (${{ matrix.os }}) needs: [ clone ] + if: always() runs-on: ${{ matrix.os }} strategy: fail-fast: false From bfb16fa24e338c644bf46696d5daf7187e4af27a Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 21 Jan 2026 10:18:45 +0800 Subject: [PATCH 02/40] coarce migration --- csrc/include/doc_node.h | 149 ++++++++- csrc/src/doc_node.cpp | 721 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 865 insertions(+), 5 deletions(-) diff --git a/csrc/include/doc_node.h b/csrc/include/doc_node.h index 4908d9941..cb819d538 100644 --- a/csrc/include/doc_node.h +++ b/csrc/include/doc_node.h @@ -1,19 +1,164 @@ #pragma once +#include +#include +#include +#include #include +#include +#include namespace lazyllm { +constexpr const char kRagKbId[] = "kb_id"; +constexpr const char kRagDocId[] = "docid"; +constexpr const char kRagDocPath[] = "lazyllm_doc_path"; + +enum class MetadataMode { + All, + Embed, + Llm, + None, +}; + class DocNode { public: - DocNode() = default; + using Embedding = std::unordered_map>; + using Metadata = std::unordered_map; + using Children = std::unordered_map>; + using EmbeddingFn = std::function(const std::string&, const std::string&)>; + + DocNode(); explicit DocNode(const std::string& text); + DocNode(const DocNode& other); + DocNode& operator=(const DocNode& other); + virtual ~DocNode() = default; + + const std::string& uid() const; + const std::string& group() const; + void set_group(const std::string& group); + + bool content_is_list() const; + const std::vector& content_list() const; + const 
std::string& content_text() const; + + void set_content(const std::string& text); + void set_content(const std::vector& lines); void set_text(const std::string& text); const std::string& get_text() const; + virtual std::string get_text_with_metadata(MetadataMode mode) const; + + std::string content_hash() const; + + Embedding& embedding(); + const Embedding& embedding() const; + void set_embedding(const Embedding& embed); + + std::vector has_missing_embedding(const std::vector& embed_keys) const; + virtual void do_embedding(const std::unordered_map& embed); + void set_embedding_value(const std::string& key, const std::vector& value); + void check_embedding_state(const std::string& embed_key) const; + + DocNode* parent(); + const DocNode* parent() const; + void set_parent(DocNode* parent); + + Children& children(); + const Children& children() const; + void set_children(const Children& children); + + DocNode* root_node(); + const DocNode* root_node() const; + bool is_root_node() const; + + Metadata& metadata(); + const Metadata& metadata() const; + void set_metadata(const Metadata& metadata); + + Metadata& global_metadata(); + const Metadata& global_metadata() const; + void set_global_metadata(const Metadata& global_metadata); + + std::vector excluded_embed_metadata_keys() const; + void set_excluded_embed_metadata_keys(const std::vector& keys); + std::vector excluded_llm_metadata_keys() const; + void set_excluded_llm_metadata_keys(const std::vector& keys); + + std::string docpath() const; + void set_docpath(const std::string& path); + + std::string get_children_str() const; + std::string get_parent_id() const; + + std::string to_string() const; + bool operator==(const DocNode& other) const; + bool operator!=(const DocNode& other) const; + std::size_t hash() const; + + std::string get_metadata_str(MetadataMode mode = MetadataMode::All) const; + virtual std::string get_content(MetadataMode mode = MetadataMode::Llm) const; + + DocNode with_score(double score) const; + 
DocNode with_sim_score(double score) const; + + bool has_relevance_score() const; + bool has_similarity_score() const; + double relevance_score() const; + double similarity_score() const; + +protected: + void invalidate_content_hash(); + + std::string uid_; + std::string group_; + std::string text_; + bool content_is_list_; + std::vector content_list_; + Embedding embedding_; + Metadata metadata_; + Metadata global_metadata_; + std::vector excluded_embed_metadata_keys_; + std::vector excluded_llm_metadata_keys_; + DocNode* parent_; + Children children_; + bool children_loaded_; + mutable std::mutex embedding_mutex_; + mutable std::set embedding_state_; + mutable std::string content_hash_; + mutable bool content_hash_dirty_; + double relevance_score_; + bool has_relevance_score_; + double similarity_score_; + bool has_similarity_score_; +}; + +class QADocNode : public DocNode { +public: + QADocNode(const std::string& query, const std::string& answer); + QADocNode(const std::string& query, const std::string& answer, const std::string& uid, + const std::string& group = std::string()); + const std::string& answer() const; + std::string get_text_with_metadata(MetadataMode mode) const override; + +private: + std::string answer_; +}; + +class ImageDocNode : public DocNode { +public: + ImageDocNode(const std::string& image_path); + ImageDocNode(const std::string& image_path, const std::string& uid, + const std::string& group = std::string()); + + const std::string& image_path() const; + std::string get_content(MetadataMode mode = MetadataMode::Llm) const override; + void do_embedding(const std::unordered_map& embed) override; + std::string get_text_with_metadata(MetadataMode mode) const override; private: - std::string _text; + std::string image_path_; + std::string modality_; }; } // namespace lazyllm diff --git a/csrc/src/doc_node.cpp b/csrc/src/doc_node.cpp index f774ec0e5..6baf8107a 100644 --- a/csrc/src/doc_node.cpp +++ b/csrc/src/doc_node.cpp @@ -1,15 +1,730 @@ 
#include "doc_node.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + namespace lazyllm { +namespace { + +std::string JoinLines(const std::vector& lines) { + std::ostringstream oss; + for (size_t i = 0; i < lines.size(); ++i) { + if (i > 0) { + oss << "\n"; + } + oss << lines[i]; + } + return oss.str(); +} + +std::string Trim(const std::string& value) { + size_t start = 0; + while (start < value.size() && std::isspace(static_cast(value[start]))) { + ++start; + } + size_t end = value.size(); + while (end > start && std::isspace(static_cast(value[end - 1]))) { + --end; + } + return value.substr(start, end - start); +} + +std::string GenerateUUID() { + static const char kHex[] = "0123456789abcdef"; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist(0, 15); + const int groups[] = {8, 4, 4, 4, 12}; + std::string out; + out.reserve(36); + for (size_t group = 0; group < 5; ++group) { + if (group > 0) { + out.push_back('-'); + } + for (int i = 0; i < groups[group]; ++i) { + out.push_back(kHex[dist(gen)]); + } + } + return out; +} + +std::string ToLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return value; +} + +std::string GetExtension(const std::string& path) { + const size_t dot = path.find_last_of('.'); + if (dot == std::string::npos || dot + 1 >= path.size()) { + return ""; + } + return ToLower(path.substr(dot + 1)); +} + +std::string ImageMimeType(const std::string& path) { + const std::string ext = GetExtension(path); + if (ext == "jpg" || ext == "jpeg" || ext == "jfif" || ext == "jpe") { + return "image/jpeg"; + } + if (ext == "png" || ext == "apng") { + return "image/png"; + } + if (ext == "gif") { + return "image/gif"; + } + if (ext == "bmp" || ext == "dib") { + return "image/bmp"; + } + if (ext == "tif" || ext == "tiff") { + return "image/tiff"; + } + if (ext == 
"webp") { + return "image/webp"; + } + if (ext == "ico") { + return "image/x-icon"; + } + if (ext == "icns") { + return "image/icns"; + } + return ""; +} + +bool ReadFileBinary(const std::string& path, std::string* out) { + std::ifstream file(path.c_str(), std::ios::binary); + if (!file) { + return false; + } + std::ostringstream buffer; + buffer << file.rdbuf(); + *out = buffer.str(); + return true; +} + +std::string Base64Encode(const std::string& data) { + static const char kBase64Chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string out; + out.reserve(((data.size() + 2) / 3) * 4); + int val = 0; + int valb = -6; + for (unsigned char c : data) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(kBase64Chars[(val >> valb) & 0x3F]); + valb -= 6; + } + } + if (valb > -6) { + out.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]); + } + while (out.size() % 4) { + out.push_back('='); + } + return out; +} + +uint32_t RotateRight(uint32_t value, uint32_t bits) { + return (value >> bits) | (value << (32 - bits)); +} + +std::string Sha256Hex(const std::string& input) { + static const uint32_t k[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, + 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, + 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, + 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, + 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, + 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, + 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, + 0xc67178f2, + }; + uint32_t h[8] = { 
+ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, + }; + + std::vector msg(input.begin(), input.end()); + const uint64_t bit_len = static_cast(msg.size()) * 8; + msg.push_back(0x80); + while ((msg.size() % 64) != 56) { + msg.push_back(0x00); + } + for (int i = 7; i >= 0; --i) { + msg.push_back(static_cast((bit_len >> (i * 8)) & 0xFF)); + } + + for (size_t offset = 0; offset < msg.size(); offset += 64) { + uint32_t w[64]; + for (size_t i = 0; i < 16; ++i) { + const size_t idx = offset + i * 4; + w[i] = (static_cast(msg[idx]) << 24) | + (static_cast(msg[idx + 1]) << 16) | + (static_cast(msg[idx + 2]) << 8) | + (static_cast(msg[idx + 3])); + } + for (size_t i = 16; i < 64; ++i) { + const uint32_t s0 = RotateRight(w[i - 15], 7) ^ RotateRight(w[i - 15], 18) ^ (w[i - 15] >> 3); + const uint32_t s1 = RotateRight(w[i - 2], 17) ^ RotateRight(w[i - 2], 19) ^ (w[i - 2] >> 10); + w[i] = w[i - 16] + s0 + w[i - 7] + s1; + } + + uint32_t a = h[0]; + uint32_t b = h[1]; + uint32_t c = h[2]; + uint32_t d = h[3]; + uint32_t e = h[4]; + uint32_t f = h[5]; + uint32_t g = h[6]; + uint32_t h0 = h[7]; + + for (size_t i = 0; i < 64; ++i) { + const uint32_t s1 = RotateRight(e, 6) ^ RotateRight(e, 11) ^ RotateRight(e, 25); + const uint32_t ch = (e & f) ^ (~e & g); + const uint32_t temp1 = h0 + s1 + ch + k[i] + w[i]; + const uint32_t s0 = RotateRight(a, 2) ^ RotateRight(a, 13) ^ RotateRight(a, 22); + const uint32_t maj = (a & b) ^ (a & c) ^ (b & c); + const uint32_t temp2 = s0 + maj; + + h0 = g; + g = f; + f = e; + e = d + temp1; + d = c; + c = b; + b = a; + a = temp1 + temp2; + } + + h[0] += a; + h[1] += b; + h[2] += c; + h[3] += d; + h[4] += e; + h[5] += f; + h[6] += g; + h[7] += h0; + } + + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + for (size_t i = 0; i < 8; ++i) { + oss << std::setw(8) << h[i]; + } + return oss.str(); +} + +} // namespace + +DocNode::DocNode() + : uid_(GenerateUUID()), + group_(), + text_(), + 
content_is_list_(false), + parent_(nullptr), + children_loaded_(false), + content_hash_(), + content_hash_dirty_(true), + relevance_score_(0.0), + has_relevance_score_(false), + similarity_score_(0.0), + has_similarity_score_(false) {} + +DocNode::DocNode(const std::string& text) : DocNode() { + set_text(text); +} + +DocNode::DocNode(const DocNode& other) + : uid_(other.uid_), + group_(other.group_), + text_(other.text_), + content_is_list_(other.content_is_list_), + content_list_(other.content_list_), + embedding_(other.embedding_), + metadata_(other.metadata_), + global_metadata_(other.global_metadata_), + excluded_embed_metadata_keys_(other.excluded_embed_metadata_keys_), + excluded_llm_metadata_keys_(other.excluded_llm_metadata_keys_), + parent_(other.parent_), + children_(other.children_), + children_loaded_(other.children_loaded_), + embedding_state_(other.embedding_state_), + content_hash_(other.content_hash_), + content_hash_dirty_(other.content_hash_dirty_), + relevance_score_(other.relevance_score_), + has_relevance_score_(other.has_relevance_score_), + similarity_score_(other.similarity_score_), + has_similarity_score_(other.has_similarity_score_) {} + +DocNode& DocNode::operator=(const DocNode& other) { + if (this == &other) { + return *this; + } + uid_ = other.uid_; + group_ = other.group_; + text_ = other.text_; + content_is_list_ = other.content_is_list_; + content_list_ = other.content_list_; + embedding_ = other.embedding_; + metadata_ = other.metadata_; + global_metadata_ = other.global_metadata_; + excluded_embed_metadata_keys_ = other.excluded_embed_metadata_keys_; + excluded_llm_metadata_keys_ = other.excluded_llm_metadata_keys_; + parent_ = other.parent_; + children_ = other.children_; + children_loaded_ = other.children_loaded_; + embedding_state_ = other.embedding_state_; + content_hash_ = other.content_hash_; + content_hash_dirty_ = other.content_hash_dirty_; + relevance_score_ = other.relevance_score_; + has_relevance_score_ = 
other.has_relevance_score_; + similarity_score_ = other.similarity_score_; + has_similarity_score_ = other.has_similarity_score_; + return *this; +} + +const std::string& DocNode::uid() const { + return uid_; +} + +const std::string& DocNode::group() const { + return group_; +} + +void DocNode::set_group(const std::string& group) { + group_ = group; +} + +bool DocNode::content_is_list() const { + return content_is_list_; +} + +const std::vector& DocNode::content_list() const { + return content_list_; +} -DocNode::DocNode(const std::string& text) : _text(text) {} +const std::string& DocNode::content_text() const { + return text_; +} + +void DocNode::set_content(const std::string& text) { + set_text(text); +} + +void DocNode::set_content(const std::vector& lines) { + content_is_list_ = true; + content_list_ = lines; + text_ = JoinLines(lines); + invalidate_content_hash(); +} void DocNode::set_text(const std::string& text) { - _text = text; + text_ = text; + content_is_list_ = false; + content_list_.clear(); + invalidate_content_hash(); } const std::string& DocNode::get_text() const { - return _text; + return text_; +} + +std::string DocNode::get_text_with_metadata(MetadataMode mode) const { + const std::string metadata_str = get_metadata_str(mode); + if (metadata_str.empty()) { + return text_; + } + if (text_.empty()) { + return metadata_str; + } + return metadata_str + "\n\n" + text_; +} + +std::string DocNode::content_hash() const { + if (content_hash_dirty_) { + content_hash_ = Sha256Hex(text_); + content_hash_dirty_ = false; + } + return content_hash_; +} + +DocNode::Embedding& DocNode::embedding() { + return embedding_; +} + +const DocNode::Embedding& DocNode::embedding() const { + return embedding_; +} + +void DocNode::set_embedding(const Embedding& embed) { + std::lock_guard lock(embedding_mutex_); + embedding_ = embed; +} + +std::vector DocNode::has_missing_embedding(const std::vector& embed_keys) const { + std::vector missing; + if (embed_keys.empty()) { + 
return missing; + } + std::lock_guard lock(embedding_mutex_); + for (const auto& key : embed_keys) { + if (embedding_.find(key) == embedding_.end()) { + missing.push_back(key); + } + } + return missing; +} + +void DocNode::do_embedding(const std::unordered_map& embed) { + Embedding generated; + const std::string input = get_text_with_metadata(MetadataMode::Embed); + for (const auto& item : embed) { + generated[item.first] = item.second(input, ""); + } + std::lock_guard lock(embedding_mutex_); + for (const auto& item : generated) { + embedding_[item.first] = item.second; + } +} + +void DocNode::set_embedding_value(const std::string& key, const std::vector& value) { + std::lock_guard lock(embedding_mutex_); + embedding_[key] = value; +} + +void DocNode::check_embedding_state(const std::string& embed_key) const { + while (true) { + { + std::lock_guard lock(embedding_mutex_); + if (embedding_.find(embed_key) != embedding_.end()) { + embedding_state_.erase(embed_key); + break; + } + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } +} + +DocNode* DocNode::parent() { + return parent_; +} + +const DocNode* DocNode::parent() const { + return parent_; +} + +void DocNode::set_parent(DocNode* parent) { + parent_ = parent; +} + +DocNode::Children& DocNode::children() { + return children_; +} + +const DocNode::Children& DocNode::children() const { + return children_; +} + +void DocNode::set_children(const Children& children) { + children_ = children; +} + +DocNode* DocNode::root_node() { + DocNode* node = this; + while (node->parent_ != nullptr) { + node = node->parent_; + } + return node; +} + +const DocNode* DocNode::root_node() const { + const DocNode* node = this; + while (node->parent_ != nullptr) { + node = node->parent_; + } + return node; +} + +bool DocNode::is_root_node() const { + return parent_ == nullptr; +} + +DocNode::Metadata& DocNode::metadata() { + return metadata_; +} + +const DocNode::Metadata& DocNode::metadata() const { + return metadata_; +} + 
+void DocNode::set_metadata(const Metadata& metadata) { + metadata_ = metadata; +} + +DocNode::Metadata& DocNode::global_metadata() { + return root_node()->global_metadata_; +} + +const DocNode::Metadata& DocNode::global_metadata() const { + return root_node()->global_metadata_; +} + +void DocNode::set_global_metadata(const Metadata& global_metadata) { + global_metadata_ = global_metadata; +} + +std::vector DocNode::excluded_embed_metadata_keys() const { + std::set keys; + const DocNode* root = root_node(); + keys.insert(root->excluded_embed_metadata_keys_.begin(), root->excluded_embed_metadata_keys_.end()); + keys.insert(excluded_embed_metadata_keys_.begin(), excluded_embed_metadata_keys_.end()); + return std::vector(keys.begin(), keys.end()); +} + +void DocNode::set_excluded_embed_metadata_keys(const std::vector& keys) { + excluded_embed_metadata_keys_ = keys; +} + +std::vector DocNode::excluded_llm_metadata_keys() const { + std::set keys; + const DocNode* root = root_node(); + keys.insert(root->excluded_llm_metadata_keys_.begin(), root->excluded_llm_metadata_keys_.end()); + keys.insert(excluded_llm_metadata_keys_.begin(), excluded_llm_metadata_keys_.end()); + return std::vector(keys.begin(), keys.end()); +} + +void DocNode::set_excluded_llm_metadata_keys(const std::vector& keys) { + excluded_llm_metadata_keys_ = keys; +} + +std::string DocNode::docpath() const { + const auto& meta = global_metadata(); + const auto it = meta.find(kRagDocPath); + if (it == meta.end()) { + return ""; + } + return it->second; +} + +void DocNode::set_docpath(const std::string& path) { + if (!is_root_node()) { + throw std::runtime_error("Only root node can set docpath."); + } + global_metadata()[kRagDocPath] = path; +} + +std::string DocNode::get_children_str() const { + std::ostringstream oss; + oss << "{"; + bool first_group = true; + for (const auto& item : children_) { + if (!first_group) { + oss << ", "; + } + first_group = false; + oss << item.first << ": ["; + bool first_child 
= true; + for (const auto* node : item.second) { + if (!node) { + continue; + } + if (!first_child) { + oss << ", "; + } + first_child = false; + oss << node->uid(); + } + oss << "]"; + } + oss << "}"; + return oss.str(); +} + +std::string DocNode::get_parent_id() const { + return parent_ ? parent_->uid() : ""; +} + +std::string DocNode::to_string() const { + std::ostringstream oss; + oss << "DocNode(id: " << uid_ << ", group: " << group_ << ", content: " << text_ << ") parent: " + << get_parent_id() << ", children: " << get_children_str(); + return oss.str(); +} + +bool DocNode::operator==(const DocNode& other) const { + return uid_ == other.uid_; +} + +bool DocNode::operator!=(const DocNode& other) const { + return !(*this == other); +} + +std::size_t DocNode::hash() const { + return std::hash()(uid_); +} + +std::string DocNode::get_metadata_str(MetadataMode mode) const { + if (mode == MetadataMode::None) { + return ""; + } + std::set keys; + for (const auto& item : metadata_) { + keys.insert(item.first); + } + if (mode == MetadataMode::Llm) { + const auto excluded = excluded_llm_metadata_keys(); + for (const auto& key : excluded) { + keys.erase(key); + } + } else if (mode == MetadataMode::Embed) { + const auto excluded = excluded_embed_metadata_keys(); + for (const auto& key : excluded) { + keys.erase(key); + } + } + std::ostringstream oss; + bool first = true; + for (const auto& key : keys) { + const auto it = metadata_.find(key); + if (it == metadata_.end()) { + continue; + } + if (!first) { + oss << "\n"; + } + first = false; + oss << key << ": " << it->second; + } + return oss.str(); +} + +std::string DocNode::get_content(MetadataMode mode) const { + if (mode == MetadataMode::Llm) { + return get_text_with_metadata(MetadataMode::Llm); + } + return get_text_with_metadata(mode); +} + +DocNode DocNode::with_score(double score) const { + DocNode node(*this); + node.relevance_score_ = score; + node.has_relevance_score_ = true; + return node; +} + +DocNode 
DocNode::with_sim_score(double score) const { + DocNode node(*this); + node.similarity_score_ = score; + node.has_similarity_score_ = true; + return node; +} + +bool DocNode::has_relevance_score() const { + return has_relevance_score_; +} + +bool DocNode::has_similarity_score() const { + return has_similarity_score_; +} + +double DocNode::relevance_score() const { + return relevance_score_; +} + +double DocNode::similarity_score() const { + return similarity_score_; +} + +void DocNode::invalidate_content_hash() { + content_hash_dirty_ = true; +} + +QADocNode::QADocNode(const std::string& query, const std::string& answer) + : DocNode(query), answer_(Trim(answer)) {} + +QADocNode::QADocNode(const std::string& query, const std::string& answer, const std::string& uid, + const std::string& group) + : DocNode(query), answer_(Trim(answer)) { + if (!uid.empty()) { + uid_ = uid; + } + group_ = group; +} + +const std::string& QADocNode::answer() const { + return answer_; +} + +std::string QADocNode::get_text_with_metadata(MetadataMode mode) const { + if (mode == MetadataMode::Llm) { + std::ostringstream oss; + oss << "query:\n" << text_ << "\nanswer\n" << answer_; + return oss.str(); + } + return DocNode::get_text_with_metadata(mode); +} + +ImageDocNode::ImageDocNode(const std::string& image_path) + : DocNode(image_path), image_path_(Trim(image_path)), modality_("image") { + set_text(image_path_); +} + +ImageDocNode::ImageDocNode(const std::string& image_path, const std::string& uid, const std::string& group) + : DocNode(image_path), image_path_(Trim(image_path)), modality_("image") { + if (!uid.empty()) { + uid_ = uid; + } + group_ = group; + set_text(image_path_); +} + +const std::string& ImageDocNode::image_path() const { + return image_path_; +} + +std::string ImageDocNode::get_content(MetadataMode mode) const { + if (mode == MetadataMode::Embed) { + std::string file_bytes; + if (!ReadFileBinary(image_path_, &file_bytes)) { + return ""; + } + const std::string mime = 
ImageMimeType(image_path_); + const std::string base64 = Base64Encode(file_bytes); + if (mime.empty()) { + return base64; + } + return "data:" + mime + ";base64," + base64; + } + return image_path_; +} + +void ImageDocNode::do_embedding(const std::unordered_map& embed) { + Embedding generated; + const std::string input = get_content(MetadataMode::Embed); + for (const auto& item : embed) { + generated[item.first] = item.second(input, modality_); + } + std::lock_guard lock(embedding_mutex_); + for (const auto& item : generated) { + embedding_[item.first] = item.second; + } +} + +std::string ImageDocNode::get_text_with_metadata(MetadataMode mode) const { + (void)mode; + return image_path_; } } // namespace lazyllm From 54e43c17b39054b3aef6b7151195b3f9bf09923b Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 21 Jan 2026 10:44:49 +0800 Subject: [PATCH 03/40] underline --- csrc/include/doc_node.h | 48 +++---- csrc/src/doc_node.cpp | 290 ++++++++++++++++++++-------------------- 2 files changed, 169 insertions(+), 169 deletions(-) diff --git a/csrc/include/doc_node.h b/csrc/include/doc_node.h index cb819d538..d8aaa6219 100644 --- a/csrc/include/doc_node.h +++ b/csrc/include/doc_node.h @@ -110,27 +110,27 @@ class DocNode { protected: void invalidate_content_hash(); - std::string uid_; - std::string group_; - std::string text_; - bool content_is_list_; - std::vector content_list_; - Embedding embedding_; - Metadata metadata_; - Metadata global_metadata_; - std::vector excluded_embed_metadata_keys_; - std::vector excluded_llm_metadata_keys_; - DocNode* parent_; - Children children_; - bool children_loaded_; - mutable std::mutex embedding_mutex_; - mutable std::set embedding_state_; - mutable std::string content_hash_; - mutable bool content_hash_dirty_; - double relevance_score_; - bool has_relevance_score_; - double similarity_score_; - bool has_similarity_score_; + std::string _uid; + std::string _group; + std::string _text; + bool _content_is_list; + std::vector 
_content_list; + Embedding _embedding; + Metadata _metadata; + Metadata _global_metadata; + std::vector _excluded_embed_metadata_keys; + std::vector _excluded_llm_metadata_keys; + DocNode* _parent; + Children _children; + bool _children_loaded; + mutable std::mutex _embedding_mutex; + mutable std::set _embedding_state; + mutable std::string _content_hash; + mutable bool _content_hash_dirty; + double _relevance_score; + bool _has_relevance_score; + double _similarity_score; + bool _has_similarity_score; }; class QADocNode : public DocNode { @@ -142,7 +142,7 @@ class QADocNode : public DocNode { std::string get_text_with_metadata(MetadataMode mode) const override; private: - std::string answer_; + std::string _answer; }; class ImageDocNode : public DocNode { @@ -157,8 +157,8 @@ class ImageDocNode : public DocNode { std::string get_text_with_metadata(MetadataMode mode) const override; private: - std::string image_path_; - std::string modality_; + std::string _image_path; + std::string _modality; }; } // namespace lazyllm diff --git a/csrc/src/doc_node.cpp b/csrc/src/doc_node.cpp index 6baf8107a..e26c568a3 100644 --- a/csrc/src/doc_node.cpp +++ b/csrc/src/doc_node.cpp @@ -229,94 +229,94 @@ std::string Sha256Hex(const std::string& input) { } // namespace DocNode::DocNode() - : uid_(GenerateUUID()), - group_(), - text_(), - content_is_list_(false), - parent_(nullptr), - children_loaded_(false), - content_hash_(), - content_hash_dirty_(true), - relevance_score_(0.0), - has_relevance_score_(false), - similarity_score_(0.0), - has_similarity_score_(false) {} + : _uid(GenerateUUID()), + _group(), + _text(), + _content_is_list(false), + _parent(nullptr), + _children_loaded(false), + _content_hash(), + _content_hash_dirty(true), + _relevance_score(0.0), + _has_relevance_score(false), + _similarity_score(0.0), + _has_similarity_score(false) {} DocNode::DocNode(const std::string& text) : DocNode() { set_text(text); } DocNode::DocNode(const DocNode& other) - : uid_(other.uid_), - 
group_(other.group_), - text_(other.text_), - content_is_list_(other.content_is_list_), - content_list_(other.content_list_), - embedding_(other.embedding_), - metadata_(other.metadata_), - global_metadata_(other.global_metadata_), - excluded_embed_metadata_keys_(other.excluded_embed_metadata_keys_), - excluded_llm_metadata_keys_(other.excluded_llm_metadata_keys_), - parent_(other.parent_), - children_(other.children_), - children_loaded_(other.children_loaded_), - embedding_state_(other.embedding_state_), - content_hash_(other.content_hash_), - content_hash_dirty_(other.content_hash_dirty_), - relevance_score_(other.relevance_score_), - has_relevance_score_(other.has_relevance_score_), - similarity_score_(other.similarity_score_), - has_similarity_score_(other.has_similarity_score_) {} + : _uid(other._uid), + _group(other._group), + _text(other._text), + _content_is_list(other._content_is_list), + _content_list(other._content_list), + _embedding(other._embedding), + _metadata(other._metadata), + _global_metadata(other._global_metadata), + _excluded_embed_metadata_keys(other._excluded_embed_metadata_keys), + _excluded_llm_metadata_keys(other._excluded_llm_metadata_keys), + _parent(other._parent), + _children(other._children), + _children_loaded(other._children_loaded), + _embedding_state(other._embedding_state), + _content_hash(other._content_hash), + _content_hash_dirty(other._content_hash_dirty), + _relevance_score(other._relevance_score), + _has_relevance_score(other._has_relevance_score), + _similarity_score(other._similarity_score), + _has_similarity_score(other._has_similarity_score) {} DocNode& DocNode::operator=(const DocNode& other) { if (this == &other) { return *this; } - uid_ = other.uid_; - group_ = other.group_; - text_ = other.text_; - content_is_list_ = other.content_is_list_; - content_list_ = other.content_list_; - embedding_ = other.embedding_; - metadata_ = other.metadata_; - global_metadata_ = other.global_metadata_; - 
excluded_embed_metadata_keys_ = other.excluded_embed_metadata_keys_; - excluded_llm_metadata_keys_ = other.excluded_llm_metadata_keys_; - parent_ = other.parent_; - children_ = other.children_; - children_loaded_ = other.children_loaded_; - embedding_state_ = other.embedding_state_; - content_hash_ = other.content_hash_; - content_hash_dirty_ = other.content_hash_dirty_; - relevance_score_ = other.relevance_score_; - has_relevance_score_ = other.has_relevance_score_; - similarity_score_ = other.similarity_score_; - has_similarity_score_ = other.has_similarity_score_; + _uid = other._uid; + _group = other._group; + _text = other._text; + _content_is_list = other._content_is_list; + _content_list = other._content_list; + _embedding = other._embedding; + _metadata = other._metadata; + _global_metadata = other._global_metadata; + _excluded_embed_metadata_keys = other._excluded_embed_metadata_keys; + _excluded_llm_metadata_keys = other._excluded_llm_metadata_keys; + _parent = other._parent; + _children = other._children; + _children_loaded = other._children_loaded; + _embedding_state = other._embedding_state; + _content_hash = other._content_hash; + _content_hash_dirty = other._content_hash_dirty; + _relevance_score = other._relevance_score; + _has_relevance_score = other._has_relevance_score; + _similarity_score = other._similarity_score; + _has_similarity_score = other._has_similarity_score; return *this; } const std::string& DocNode::uid() const { - return uid_; + return _uid; } const std::string& DocNode::group() const { - return group_; + return _group; } void DocNode::set_group(const std::string& group) { - group_ = group; + _group = group; } bool DocNode::content_is_list() const { - return content_is_list_; + return _content_is_list; } const std::vector& DocNode::content_list() const { - return content_list_; + return _content_list; } const std::string& DocNode::content_text() const { - return text_; + return _text; } void DocNode::set_content(const std::string& 
text) { @@ -324,53 +324,53 @@ void DocNode::set_content(const std::string& text) { } void DocNode::set_content(const std::vector& lines) { - content_is_list_ = true; - content_list_ = lines; - text_ = JoinLines(lines); + _content_is_list = true; + _content_list = lines; + _text = JoinLines(lines); invalidate_content_hash(); } void DocNode::set_text(const std::string& text) { - text_ = text; - content_is_list_ = false; - content_list_.clear(); + _text = text; + _content_is_list = false; + _content_list.clear(); invalidate_content_hash(); } const std::string& DocNode::get_text() const { - return text_; + return _text; } std::string DocNode::get_text_with_metadata(MetadataMode mode) const { const std::string metadata_str = get_metadata_str(mode); if (metadata_str.empty()) { - return text_; + return _text; } - if (text_.empty()) { + if (_text.empty()) { return metadata_str; } - return metadata_str + "\n\n" + text_; + return metadata_str + "\n\n" + _text; } std::string DocNode::content_hash() const { - if (content_hash_dirty_) { - content_hash_ = Sha256Hex(text_); - content_hash_dirty_ = false; + if (_content_hash_dirty) { + _content_hash = Sha256Hex(_text); + _content_hash_dirty = false; } - return content_hash_; + return _content_hash; } DocNode::Embedding& DocNode::embedding() { - return embedding_; + return _embedding; } const DocNode::Embedding& DocNode::embedding() const { - return embedding_; + return _embedding; } void DocNode::set_embedding(const Embedding& embed) { - std::lock_guard lock(embedding_mutex_); - embedding_ = embed; + std::lock_guard lock(_embedding_mutex); + _embedding = embed; } std::vector DocNode::has_missing_embedding(const std::vector& embed_keys) const { @@ -378,9 +378,9 @@ std::vector DocNode::has_missing_embedding(const std::vector lock(embedding_mutex_); + std::lock_guard lock(_embedding_mutex); for (const auto& key : embed_keys) { - if (embedding_.find(key) == embedding_.end()) { + if (_embedding.find(key) == _embedding.end()) { 
missing.push_back(key); } } @@ -393,23 +393,23 @@ void DocNode::do_embedding(const std::unordered_map& e for (const auto& item : embed) { generated[item.first] = item.second(input, ""); } - std::lock_guard lock(embedding_mutex_); + std::lock_guard lock(_embedding_mutex); for (const auto& item : generated) { - embedding_[item.first] = item.second; + _embedding[item.first] = item.second; } } void DocNode::set_embedding_value(const std::string& key, const std::vector& value) { - std::lock_guard lock(embedding_mutex_); - embedding_[key] = value; + std::lock_guard lock(_embedding_mutex); + _embedding[key] = value; } void DocNode::check_embedding_state(const std::string& embed_key) const { while (true) { { - std::lock_guard lock(embedding_mutex_); - if (embedding_.find(embed_key) != embedding_.end()) { - embedding_state_.erase(embed_key); + std::lock_guard lock(_embedding_mutex); + if (_embedding.find(embed_key) != _embedding.end()) { + _embedding_state.erase(embed_key); break; } } @@ -418,95 +418,95 @@ void DocNode::check_embedding_state(const std::string& embed_key) const { } DocNode* DocNode::parent() { - return parent_; + return _parent; } const DocNode* DocNode::parent() const { - return parent_; + return _parent; } void DocNode::set_parent(DocNode* parent) { - parent_ = parent; + _parent = parent; } DocNode::Children& DocNode::children() { - return children_; + return _children; } const DocNode::Children& DocNode::children() const { - return children_; + return _children; } void DocNode::set_children(const Children& children) { - children_ = children; + _children = children; } DocNode* DocNode::root_node() { DocNode* node = this; - while (node->parent_ != nullptr) { - node = node->parent_; + while (node->_parent != nullptr) { + node = node->_parent; } return node; } const DocNode* DocNode::root_node() const { const DocNode* node = this; - while (node->parent_ != nullptr) { - node = node->parent_; + while (node->_parent != nullptr) { + node = node->_parent; } return 
node; } bool DocNode::is_root_node() const { - return parent_ == nullptr; + return _parent == nullptr; } DocNode::Metadata& DocNode::metadata() { - return metadata_; + return _metadata; } const DocNode::Metadata& DocNode::metadata() const { - return metadata_; + return _metadata; } void DocNode::set_metadata(const Metadata& metadata) { - metadata_ = metadata; + _metadata = metadata; } DocNode::Metadata& DocNode::global_metadata() { - return root_node()->global_metadata_; + return root_node()->_global_metadata; } const DocNode::Metadata& DocNode::global_metadata() const { - return root_node()->global_metadata_; + return root_node()->_global_metadata; } void DocNode::set_global_metadata(const Metadata& global_metadata) { - global_metadata_ = global_metadata; + _global_metadata = global_metadata; } std::vector DocNode::excluded_embed_metadata_keys() const { std::set keys; const DocNode* root = root_node(); - keys.insert(root->excluded_embed_metadata_keys_.begin(), root->excluded_embed_metadata_keys_.end()); - keys.insert(excluded_embed_metadata_keys_.begin(), excluded_embed_metadata_keys_.end()); + keys.insert(root->_excluded_embed_metadata_keys.begin(), root->_excluded_embed_metadata_keys.end()); + keys.insert(_excluded_embed_metadata_keys.begin(), _excluded_embed_metadata_keys.end()); return std::vector(keys.begin(), keys.end()); } void DocNode::set_excluded_embed_metadata_keys(const std::vector& keys) { - excluded_embed_metadata_keys_ = keys; + _excluded_embed_metadata_keys = keys; } std::vector DocNode::excluded_llm_metadata_keys() const { std::set keys; const DocNode* root = root_node(); - keys.insert(root->excluded_llm_metadata_keys_.begin(), root->excluded_llm_metadata_keys_.end()); - keys.insert(excluded_llm_metadata_keys_.begin(), excluded_llm_metadata_keys_.end()); + keys.insert(root->_excluded_llm_metadata_keys.begin(), root->_excluded_llm_metadata_keys.end()); + keys.insert(_excluded_llm_metadata_keys.begin(), _excluded_llm_metadata_keys.end()); return 
std::vector(keys.begin(), keys.end()); } void DocNode::set_excluded_llm_metadata_keys(const std::vector& keys) { - excluded_llm_metadata_keys_ = keys; + _excluded_llm_metadata_keys = keys; } std::string DocNode::docpath() const { @@ -529,7 +529,7 @@ std::string DocNode::get_children_str() const { std::ostringstream oss; oss << "{"; bool first_group = true; - for (const auto& item : children_) { + for (const auto& item : _children) { if (!first_group) { oss << ", "; } @@ -553,18 +553,18 @@ std::string DocNode::get_children_str() const { } std::string DocNode::get_parent_id() const { - return parent_ ? parent_->uid() : ""; + return _parent ? _parent->uid() : ""; } std::string DocNode::to_string() const { std::ostringstream oss; - oss << "DocNode(id: " << uid_ << ", group: " << group_ << ", content: " << text_ << ") parent: " + oss << "DocNode(id: " << _uid << ", group: " << _group << ", content: " << _text << ") parent: " << get_parent_id() << ", children: " << get_children_str(); return oss.str(); } bool DocNode::operator==(const DocNode& other) const { - return uid_ == other.uid_; + return _uid == other._uid; } bool DocNode::operator!=(const DocNode& other) const { @@ -572,7 +572,7 @@ bool DocNode::operator!=(const DocNode& other) const { } std::size_t DocNode::hash() const { - return std::hash()(uid_); + return std::hash()(_uid); } std::string DocNode::get_metadata_str(MetadataMode mode) const { @@ -580,7 +580,7 @@ std::string DocNode::get_metadata_str(MetadataMode mode) const { return ""; } std::set keys; - for (const auto& item : metadata_) { + for (const auto& item : _metadata) { keys.insert(item.first); } if (mode == MetadataMode::Llm) { @@ -597,8 +597,8 @@ std::string DocNode::get_metadata_str(MetadataMode mode) const { std::ostringstream oss; bool first = true; for (const auto& key : keys) { - const auto it = metadata_.find(key); - if (it == metadata_.end()) { + const auto it = _metadata.find(key); + if (it == _metadata.end()) { continue; } if (!first) { @@ 
-619,112 +619,112 @@ std::string DocNode::get_content(MetadataMode mode) const { DocNode DocNode::with_score(double score) const { DocNode node(*this); - node.relevance_score_ = score; - node.has_relevance_score_ = true; + node._relevance_score = score; + node._has_relevance_score = true; return node; } DocNode DocNode::with_sim_score(double score) const { DocNode node(*this); - node.similarity_score_ = score; - node.has_similarity_score_ = true; + node._similarity_score = score; + node._has_similarity_score = true; return node; } bool DocNode::has_relevance_score() const { - return has_relevance_score_; + return _has_relevance_score; } bool DocNode::has_similarity_score() const { - return has_similarity_score_; + return _has_similarity_score; } double DocNode::relevance_score() const { - return relevance_score_; + return _relevance_score; } double DocNode::similarity_score() const { - return similarity_score_; + return _similarity_score; } void DocNode::invalidate_content_hash() { - content_hash_dirty_ = true; + _content_hash_dirty = true; } QADocNode::QADocNode(const std::string& query, const std::string& answer) - : DocNode(query), answer_(Trim(answer)) {} + : DocNode(query), _answer(Trim(answer)) {} QADocNode::QADocNode(const std::string& query, const std::string& answer, const std::string& uid, const std::string& group) - : DocNode(query), answer_(Trim(answer)) { + : DocNode(query), _answer(Trim(answer)) { if (!uid.empty()) { - uid_ = uid; + _uid = uid; } - group_ = group; + _group = group; } const std::string& QADocNode::answer() const { - return answer_; + return _answer; } std::string QADocNode::get_text_with_metadata(MetadataMode mode) const { if (mode == MetadataMode::Llm) { std::ostringstream oss; - oss << "query:\n" << text_ << "\nanswer\n" << answer_; + oss << "query:\n" << _text << "\nanswer\n" << _answer; return oss.str(); } return DocNode::get_text_with_metadata(mode); } ImageDocNode::ImageDocNode(const std::string& image_path) - : 
DocNode(image_path), image_path_(Trim(image_path)), modality_("image") { - set_text(image_path_); + : DocNode(image_path), _image_path(Trim(image_path)), _modality("image") { + set_text(_image_path); } ImageDocNode::ImageDocNode(const std::string& image_path, const std::string& uid, const std::string& group) - : DocNode(image_path), image_path_(Trim(image_path)), modality_("image") { + : DocNode(image_path), _image_path(Trim(image_path)), _modality("image") { if (!uid.empty()) { - uid_ = uid; + _uid = uid; } - group_ = group; - set_text(image_path_); + _group = group; + set_text(_image_path); } const std::string& ImageDocNode::image_path() const { - return image_path_; + return _image_path; } std::string ImageDocNode::get_content(MetadataMode mode) const { if (mode == MetadataMode::Embed) { std::string file_bytes; - if (!ReadFileBinary(image_path_, &file_bytes)) { + if (!ReadFileBinary(_image_path, &file_bytes)) { return ""; } - const std::string mime = ImageMimeType(image_path_); + const std::string mime = ImageMimeType(_image_path); const std::string base64 = Base64Encode(file_bytes); if (mime.empty()) { return base64; } return "data:" + mime + ";base64," + base64; } - return image_path_; + return _image_path; } void ImageDocNode::do_embedding(const std::unordered_map& embed) { Embedding generated; const std::string input = get_content(MetadataMode::Embed); for (const auto& item : embed) { - generated[item.first] = item.second(input, modality_); + generated[item.first] = item.second(input, _modality); } - std::lock_guard lock(embedding_mutex_); + std::lock_guard lock(_embedding_mutex); for (const auto& item : generated) { - embedding_[item.first] = item.second; + _embedding[item.first] = item.second; } } std::string ImageDocNode::get_text_with_metadata(MetadataMode mode) const { (void)mode; - return image_path_; + return _image_path; } } // namespace lazyllm From 894bb73e3ac244220d2feb51e6261d0ca526227c Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 22 Jan 2026 
14:41:32 +0800 Subject: [PATCH 04/40] c++17 --- csrc/CMakeLists.txt | 2 +- csrc/cmake/tests.cmake | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9d0e44ebd..58e88040d 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.16) project(LazyLLMCPP LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) diff --git a/csrc/cmake/tests.cmake b/csrc/cmake/tests.cmake index 3eb114d9e..e495d3cbf 100644 --- a/csrc/cmake/tests.cmake +++ b/csrc/cmake/tests.cmake @@ -2,7 +2,7 @@ include(FetchContent) FetchContent_Declare( googletest - URL https://codeload.github.com/google/googletest/zip/refs/tags/release-1.12.1 + URL https://codeload.github.com/google/googletest/zip/refs/tags/release-1.17.0 ) # Fix gtest version to maintain C++11 compatibility. set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) From b0a1ca02c1b0b46a8bdc211334827eb3ca58cc5c Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 22 Jan 2026 16:19:43 +0800 Subject: [PATCH 05/40] rename --- csrc/include/{doc_node.h => doc_node.hpp} | 54 ++--- csrc/include/utils.hpp | 224 +++++++++++++++++++ csrc/src/doc_node.cpp | 252 ++-------------------- csrc/src/utils.cpp | 0 4 files changed, 270 insertions(+), 260 deletions(-) rename csrc/include/{doc_node.h => doc_node.hpp} (80%) create mode 100644 csrc/include/utils.hpp create mode 100644 csrc/src/utils.cpp diff --git a/csrc/include/doc_node.h b/csrc/include/doc_node.hpp similarity index 80% rename from csrc/include/doc_node.h rename to csrc/include/doc_node.hpp index d8aaa6219..109f245bf 100644 --- a/csrc/include/doc_node.h +++ b/csrc/include/doc_node.hpp @@ -1,34 +1,37 @@ #pragma once -#include #include -#include #include #include #include #include +#include +#include namespace lazyllm { -constexpr const char kRagKbId[] = "kb_id"; -constexpr const char kRagDocId[] = 
"docid"; -constexpr const char kRagDocPath[] = "lazyllm_doc_path"; - -enum class MetadataMode { - All, - Embed, - Llm, - None, -}; +enum class MetadataMode { ALL, EMBED, LLM, NONE }; class DocNode { -public: - using Embedding = std::unordered_map>; +protected: using Metadata = std::unordered_map; using Children = std::unordered_map>; - using EmbeddingFn = std::function(const std::string&, const std::string&)>; + using EmbeddingFun = std::function(const std::string&, const std::string&)>; + using EmbeddingVec = std::unordered_map>; - DocNode(); +public: + DocNode( + uid: Optional[str] = None, + content: Optional[Union[str, List[Any]]] = None, + group: Optional[str] = None, + embedding: Optional[Dict[str,List[float]]] = None, + parent: Optional[Union[str, 'DocNode']] = None, + store=None, + node_groups: Optional[Dict[str, Dict]] = None, + metadata: Optional[Dict[str, Any]] = None, + global_metadata: Optional[Dict[str, Any]] = None, + text: Optional[str] = None + ); explicit DocNode(const std::string& text); DocNode(const DocNode& other); DocNode& operator=(const DocNode& other); @@ -51,12 +54,12 @@ class DocNode { std::string content_hash() const; - Embedding& embedding(); - const Embedding& embedding() const; - void set_embedding(const Embedding& embed); + EmbeddingVec& embedding(); + const EmbeddingVec& embedding() const; + void set_embedding(const EmbeddingVec& embed); std::vector has_missing_embedding(const std::vector& embed_keys) const; - virtual void do_embedding(const std::unordered_map& embed); + virtual void do_embedding(const std::unordered_map& embed); void set_embedding_value(const std::string& key, const std::vector& value); void check_embedding_state(const std::string& embed_key) const; @@ -96,8 +99,8 @@ class DocNode { bool operator!=(const DocNode& other) const; std::size_t hash() const; - std::string get_metadata_str(MetadataMode mode = MetadataMode::All) const; - virtual std::string get_content(MetadataMode mode = MetadataMode::Llm) const; + 
std::string get_metadata_str(MetadataMode mode = MetadataMode::ALL) const; + virtual std::string get_content(MetadataMode mode = MetadataMode::LLM) const; DocNode with_score(double score) const; DocNode with_sim_score(double score) const; @@ -115,7 +118,7 @@ class DocNode { std::string _text; bool _content_is_list; std::vector _content_list; - Embedding _embedding; + EmbeddingVec _embedding; Metadata _metadata; Metadata _global_metadata; std::vector _excluded_embed_metadata_keys; @@ -123,7 +126,6 @@ class DocNode { DocNode* _parent; Children _children; bool _children_loaded; - mutable std::mutex _embedding_mutex; mutable std::set _embedding_state; mutable std::string _content_hash; mutable bool _content_hash_dirty; @@ -152,8 +154,8 @@ class ImageDocNode : public DocNode { const std::string& group = std::string()); const std::string& image_path() const; - std::string get_content(MetadataMode mode = MetadataMode::Llm) const override; - void do_embedding(const std::unordered_map& embed) override; + std::string get_content(MetadataMode mode = MetadataMode::LLM) const override; + void do_embedding(const std::unordered_map& embed) override; std::string get_text_with_metadata(MetadataMode mode) const override; private: diff --git a/csrc/include/utils.hpp b/csrc/include/utils.hpp new file mode 100644 index 000000000..25a9d7681 --- /dev/null +++ b/csrc/include/utils.hpp @@ -0,0 +1,224 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace lazyllm { + +std::string JoinLines(const std::vector& lines) { + std::ostringstream oss; + for (size_t i = 0; i < lines.size(); ++i) { + if (i > 0) { + oss << "\n"; + } + oss << lines[i]; + } + return oss.str(); +} + +std::string Trim(const std::string& value) { + size_t start = 0; + while (start < value.size() && std::isspace(static_cast(value[start]))) { + ++start; + } + size_t end = value.size(); + while (end > start && std::isspace(static_cast(value[end - 1]))) { + --end; + } + return 
value.substr(start, end - start); +} + +std::string GenerateUUID() { + static const char HEX_CHAR[] = "0123456789abcdef"; + static const int SEGS[] = {8, 4, 4, 4, 12}; + + // Single static generator per thread. + static thread_local std::mt19937 GEN(std::random_device{}()); + static thread_local std::uniform_int_distribution DIST(0, 15); + + std::string out; + out.reserve(36); + for (int segLength : SEGS) { + for (int i = 0; i < segLength; ++i) + out.push_back(HEX_CHAR[DIST(GEN)]); + if (segLength < 12) + out.push_back('-'); + } + return out; +} + +std::string ToLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return value; +} + +std::string GetExtension(const std::string& path) { + const size_t dot = path.find_last_of('.'); + if (dot == std::string::npos || dot + 1 >= path.size()) { + return ""; + } + return ToLower(path.substr(dot + 1)); +} + +std::string ImageMimeType(const std::string& path) { + const std::string ext = GetExtension(path); + if (ext == "jpg" || ext == "jpeg" || ext == "jfif" || ext == "jpe") { + return "image/jpeg"; + } + if (ext == "png" || ext == "apng") { + return "image/png"; + } + if (ext == "gif") { + return "image/gif"; + } + if (ext == "bmp" || ext == "dib") { + return "image/bmp"; + } + if (ext == "tif" || ext == "tiff") { + return "image/tiff"; + } + if (ext == "webp") { + return "image/webp"; + } + if (ext == "ico") { + return "image/x-icon"; + } + if (ext == "icns") { + return "image/icns"; + } + return ""; +} + +bool ReadFileBinary(const std::string& path, std::string* out) { + std::ifstream file(path.c_str(), std::ios::binary); + if (!file) { + return false; + } + std::ostringstream buffer; + buffer << file.rdbuf(); + *out = buffer.str(); + return true; +} + +std::string Base64Encode(const std::string& data) { + static const char kBase64Chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string 
out; + out.reserve(((data.size() + 2) / 3) * 4); + int val = 0; + int valb = -6; + for (unsigned char c : data) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(kBase64Chars[(val >> valb) & 0x3F]); + valb -= 6; + } + } + if (valb > -6) { + out.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]); + } + while (out.size() % 4) { + out.push_back('='); + } + return out; +} + +uint32_t RotateRight(uint32_t value, uint32_t bits) { + return (value >> bits) | (value << (32 - bits)); +} + +std::string Sha256Hex(const std::string& input) { + static const uint32_t k[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, + 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, + 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, + 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, + 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, + 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, + 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, + 0xc67178f2, + }; + uint32_t h[8] = { + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, + }; + + std::vector msg(input.begin(), input.end()); + const uint64_t bit_len = static_cast(msg.size()) * 8; + msg.push_back(0x80); + while ((msg.size() % 64) != 56) { + msg.push_back(0x00); + } + for (int i = 7; i >= 0; --i) { + msg.push_back(static_cast((bit_len >> (i * 8)) & 0xFF)); + } + + for (size_t offset = 0; offset < msg.size(); offset += 64) { + uint32_t w[64]; + for (size_t i = 0; i < 16; ++i) { + const size_t idx = offset + i * 4; + w[i] = (static_cast(msg[idx]) << 
24) | + (static_cast(msg[idx + 1]) << 16) | + (static_cast(msg[idx + 2]) << 8) | + (static_cast(msg[idx + 3])); + } + for (size_t i = 16; i < 64; ++i) { + const uint32_t s0 = RotateRight(w[i - 15], 7) ^ RotateRight(w[i - 15], 18) ^ (w[i - 15] >> 3); + const uint32_t s1 = RotateRight(w[i - 2], 17) ^ RotateRight(w[i - 2], 19) ^ (w[i - 2] >> 10); + w[i] = w[i - 16] + s0 + w[i - 7] + s1; + } + + uint32_t a = h[0]; + uint32_t b = h[1]; + uint32_t c = h[2]; + uint32_t d = h[3]; + uint32_t e = h[4]; + uint32_t f = h[5]; + uint32_t g = h[6]; + uint32_t h0 = h[7]; + + for (size_t i = 0; i < 64; ++i) { + const uint32_t s1 = RotateRight(e, 6) ^ RotateRight(e, 11) ^ RotateRight(e, 25); + const uint32_t ch = (e & f) ^ (~e & g); + const uint32_t temp1 = h0 + s1 + ch + k[i] + w[i]; + const uint32_t s0 = RotateRight(a, 2) ^ RotateRight(a, 13) ^ RotateRight(a, 22); + const uint32_t maj = (a & b) ^ (a & c) ^ (b & c); + const uint32_t temp2 = s0 + maj; + + h0 = g; + g = f; + f = e; + e = d + temp1; + d = c; + c = b; + b = a; + a = temp1 + temp2; + } + + h[0] += a; + h[1] += b; + h[2] += c; + h[3] += d; + h[4] += e; + h[5] += f; + h[6] += g; + h[7] += h0; + } + + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + for (size_t i = 0; i < 8; ++i) { + oss << std::setw(8) << h[i]; + } + return oss.str(); +} + +} // namespace lazyllm diff --git a/csrc/src/doc_node.cpp b/csrc/src/doc_node.cpp index e26c568a3..abcf20252 100644 --- a/csrc/src/doc_node.cpp +++ b/csrc/src/doc_node.cpp @@ -1,232 +1,16 @@ -#include "doc_node.h" +#include "doc_node.hpp" +#include "utils.hpp" #include #include #include #include -#include #include -#include #include #include #include namespace lazyllm { -namespace { - -std::string JoinLines(const std::vector& lines) { - std::ostringstream oss; - for (size_t i = 0; i < lines.size(); ++i) { - if (i > 0) { - oss << "\n"; - } - oss << lines[i]; - } - return oss.str(); -} - -std::string Trim(const std::string& value) { - size_t start = 0; - while (start < 
value.size() && std::isspace(static_cast(value[start]))) { - ++start; - } - size_t end = value.size(); - while (end > start && std::isspace(static_cast(value[end - 1]))) { - --end; - } - return value.substr(start, end - start); -} - -std::string GenerateUUID() { - static const char kHex[] = "0123456789abcdef"; - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dist(0, 15); - const int groups[] = {8, 4, 4, 4, 12}; - std::string out; - out.reserve(36); - for (size_t group = 0; group < 5; ++group) { - if (group > 0) { - out.push_back('-'); - } - for (int i = 0; i < groups[group]; ++i) { - out.push_back(kHex[dist(gen)]); - } - } - return out; -} - -std::string ToLower(std::string value) { - std::transform(value.begin(), value.end(), value.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - return value; -} - -std::string GetExtension(const std::string& path) { - const size_t dot = path.find_last_of('.'); - if (dot == std::string::npos || dot + 1 >= path.size()) { - return ""; - } - return ToLower(path.substr(dot + 1)); -} - -std::string ImageMimeType(const std::string& path) { - const std::string ext = GetExtension(path); - if (ext == "jpg" || ext == "jpeg" || ext == "jfif" || ext == "jpe") { - return "image/jpeg"; - } - if (ext == "png" || ext == "apng") { - return "image/png"; - } - if (ext == "gif") { - return "image/gif"; - } - if (ext == "bmp" || ext == "dib") { - return "image/bmp"; - } - if (ext == "tif" || ext == "tiff") { - return "image/tiff"; - } - if (ext == "webp") { - return "image/webp"; - } - if (ext == "ico") { - return "image/x-icon"; - } - if (ext == "icns") { - return "image/icns"; - } - return ""; -} - -bool ReadFileBinary(const std::string& path, std::string* out) { - std::ifstream file(path.c_str(), std::ios::binary); - if (!file) { - return false; - } - std::ostringstream buffer; - buffer << file.rdbuf(); - *out = buffer.str(); - return true; -} - -std::string Base64Encode(const 
std::string& data) { - static const char kBase64Chars[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(((data.size() + 2) / 3) * 4); - int val = 0; - int valb = -6; - for (unsigned char c : data) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - out.push_back(kBase64Chars[(val >> valb) & 0x3F]); - valb -= 6; - } - } - if (valb > -6) { - out.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]); - } - while (out.size() % 4) { - out.push_back('='); - } - return out; -} - -uint32_t RotateRight(uint32_t value, uint32_t bits) { - return (value >> bits) | (value << (32 - bits)); -} - -std::string Sha256Hex(const std::string& input) { - static const uint32_t k[64] = { - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, - 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, - 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, - 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, - 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, - 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, - 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, - 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, - 0xc67178f2, - }; - uint32_t h[8] = { - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, - }; - - std::vector msg(input.begin(), input.end()); - const uint64_t bit_len = static_cast(msg.size()) * 8; - msg.push_back(0x80); - while ((msg.size() % 64) != 56) { - msg.push_back(0x00); - } - for (int i = 7; i >= 0; --i) { - msg.push_back(static_cast((bit_len >> (i * 8)) & 0xFF)); - } - - for (size_t offset = 0; offset < msg.size(); 
offset += 64) { - uint32_t w[64]; - for (size_t i = 0; i < 16; ++i) { - const size_t idx = offset + i * 4; - w[i] = (static_cast(msg[idx]) << 24) | - (static_cast(msg[idx + 1]) << 16) | - (static_cast(msg[idx + 2]) << 8) | - (static_cast(msg[idx + 3])); - } - for (size_t i = 16; i < 64; ++i) { - const uint32_t s0 = RotateRight(w[i - 15], 7) ^ RotateRight(w[i - 15], 18) ^ (w[i - 15] >> 3); - const uint32_t s1 = RotateRight(w[i - 2], 17) ^ RotateRight(w[i - 2], 19) ^ (w[i - 2] >> 10); - w[i] = w[i - 16] + s0 + w[i - 7] + s1; - } - - uint32_t a = h[0]; - uint32_t b = h[1]; - uint32_t c = h[2]; - uint32_t d = h[3]; - uint32_t e = h[4]; - uint32_t f = h[5]; - uint32_t g = h[6]; - uint32_t h0 = h[7]; - - for (size_t i = 0; i < 64; ++i) { - const uint32_t s1 = RotateRight(e, 6) ^ RotateRight(e, 11) ^ RotateRight(e, 25); - const uint32_t ch = (e & f) ^ (~e & g); - const uint32_t temp1 = h0 + s1 + ch + k[i] + w[i]; - const uint32_t s0 = RotateRight(a, 2) ^ RotateRight(a, 13) ^ RotateRight(a, 22); - const uint32_t maj = (a & b) ^ (a & c) ^ (b & c); - const uint32_t temp2 = s0 + maj; - - h0 = g; - g = f; - f = e; - e = d + temp1; - d = c; - c = b; - b = a; - a = temp1 + temp2; - } - - h[0] += a; - h[1] += b; - h[2] += c; - h[3] += d; - h[4] += e; - h[5] += f; - h[6] += g; - h[7] += h0; - } - - std::ostringstream oss; - oss << std::hex << std::setfill('0'); - for (size_t i = 0; i < 8; ++i) { - oss << std::setw(8) << h[i]; - } - return oss.str(); -} - -} // namespace DocNode::DocNode() : _uid(GenerateUUID()), @@ -360,15 +144,15 @@ std::string DocNode::content_hash() const { return _content_hash; } -DocNode::Embedding& DocNode::embedding() { +DocNode::EmbeddingVec& DocNode::embedding() { return _embedding; } -const DocNode::Embedding& DocNode::embedding() const { +const DocNode::EmbeddingVec& DocNode::embedding() const { return _embedding; } -void DocNode::set_embedding(const Embedding& embed) { +void DocNode::set_embedding(const EmbeddingVec& embed) { std::lock_guard 
lock(_embedding_mutex); _embedding = embed; } @@ -387,9 +171,9 @@ std::vector DocNode::has_missing_embedding(const std::vector& embed) { - Embedding generated; - const std::string input = get_text_with_metadata(MetadataMode::Embed); +void DocNode::do_embedding(const std::unordered_map& embed) { + EmbeddingVec generated; + const std::string input = get_text_with_metadata(MetadataMode::EMBED); for (const auto& item : embed) { generated[item.first] = item.second(input, ""); } @@ -576,19 +360,19 @@ std::size_t DocNode::hash() const { } std::string DocNode::get_metadata_str(MetadataMode mode) const { - if (mode == MetadataMode::None) { + if (mode == MetadataMode::NONE) { return ""; } std::set keys; for (const auto& item : _metadata) { keys.insert(item.first); } - if (mode == MetadataMode::Llm) { + if (mode == MetadataMode::LLM) { const auto excluded = excluded_llm_metadata_keys(); for (const auto& key : excluded) { keys.erase(key); } - } else if (mode == MetadataMode::Embed) { + } else if (mode == MetadataMode::EMBED) { const auto excluded = excluded_embed_metadata_keys(); for (const auto& key : excluded) { keys.erase(key); @@ -611,8 +395,8 @@ std::string DocNode::get_metadata_str(MetadataMode mode) const { } std::string DocNode::get_content(MetadataMode mode) const { - if (mode == MetadataMode::Llm) { - return get_text_with_metadata(MetadataMode::Llm); + if (mode == MetadataMode::LLM) { + return get_text_with_metadata(MetadataMode::LLM); } return get_text_with_metadata(mode); } @@ -668,7 +452,7 @@ const std::string& QADocNode::answer() const { } std::string QADocNode::get_text_with_metadata(MetadataMode mode) const { - if (mode == MetadataMode::Llm) { + if (mode == MetadataMode::LLM) { std::ostringstream oss; oss << "query:\n" << _text << "\nanswer\n" << _answer; return oss.str(); @@ -695,7 +479,7 @@ const std::string& ImageDocNode::image_path() const { } std::string ImageDocNode::get_content(MetadataMode mode) const { - if (mode == MetadataMode::Embed) { + if (mode == 
MetadataMode::EMBED) { std::string file_bytes; if (!ReadFileBinary(_image_path, &file_bytes)) { return ""; @@ -710,9 +494,9 @@ std::string ImageDocNode::get_content(MetadataMode mode) const { return _image_path; } -void ImageDocNode::do_embedding(const std::unordered_map& embed) { - Embedding generated; - const std::string input = get_content(MetadataMode::Embed); +void ImageDocNode::do_embedding(const std::unordered_map& embed) { + EmbeddingVec generated; + const std::string input = get_content(MetadataMode::EMBED); for (const auto& item : embed) { generated[item.first] = item.second(input, _modality); } diff --git a/csrc/src/utils.cpp b/csrc/src/utils.cpp new file mode 100644 index 000000000..e69de29bb From 8e7e017504248d97b9343b5b37dbf53dc48f31c0 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 23 Jan 2026 15:35:44 +0800 Subject: [PATCH 06/40] save --- csrc/CMakeLists.txt | 14 +- .../{doc.cpp => export_add_doc_str.cpp} | 2 +- csrc/binding/export_doc_node.cpp | 117 +++++++++++++++ csrc/binding/lazyllm.cpp | 14 +- csrc/binding/lazyllm.hpp | 7 +- csrc/bridge/bridge.cpp | 1 + csrc/bridge/document_store.hpp | 40 ++++++ csrc/cmake/tests.cmake | 2 +- csrc/include/doc_node.hpp | 129 +++++++++-------- csrc/scripts/config_cmake.sh | 6 + csrc/src/doc_node.cpp | 136 ++---------------- lazyllm/tools/rag/doc_node.py | 4 +- lazyllm/tools/rag/utils.py | 7 +- 13 files changed, 283 insertions(+), 196 deletions(-) rename csrc/binding/{doc.cpp => export_add_doc_str.cpp} (97%) create mode 100644 csrc/binding/export_doc_node.cpp create mode 100644 csrc/bridge/bridge.cpp create mode 100644 csrc/bridge/document_store.hpp create mode 100644 csrc/scripts/config_cmake.sh diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 58e88040d..072771e18 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -8,13 +8,23 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) find_package(pybind11 CONFIG REQUIRED) +# Config 
lazyllm_bridge lib which defines tmp classes to bridge python and cpp. +file(GLOB_RECURSE LAZYLLM_BRIDGE_SOURCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/bridge/*.cpp") +add_library(lazyllm_bridge STATIC ${LAZYLLM_BRIDGE_SOURCES}) +target_include_directories(lazyllm_bridge PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/bridge) +target_link_libraries(lazyllm_bridge PUBLIC pybind11::headers Python3::Python) + # Config lazyllm_core lib with pure cpp code. -file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") +file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_link_libraries(lazyllm_core PUBLIC lazyllm_bridge) # Config lazyllm_cpp lib with binding infomations. -set(LAZYLLM_BINDING_SOURCES binding/lazyllm.cpp binding/doc.cpp) +file(GLOB_RECURSE LAZYLLM_BINDING_SOURCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/binding/*.cpp") set(INTERFACE_TARGET_NAME lazyllm_cpp) pybind11_add_module(${INTERFACE_TARGET_NAME} ${LAZYLLM_BINDING_SOURCES}) target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core) diff --git a/csrc/binding/doc.cpp b/csrc/binding/export_add_doc_str.cpp similarity index 97% rename from csrc/binding/doc.cpp rename to csrc/binding/export_add_doc_str.cpp index 587315a2a..4ef0fd53b 100644 --- a/csrc/binding/doc.cpp +++ b/csrc/binding/export_add_doc_str.cpp @@ -28,6 +28,6 @@ void addDocStr(py::object obj, std::string docs) { } } -void exportDoc(py::module& m) { +void exportAddDocStr(py::module& m) { m.def("add_doc", &addDocStr, "Add docstring to a function or method", py::arg("obj"), py::arg("docs")); } diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp new file mode 100644 index 000000000..c74369cf7 --- /dev/null +++ b/csrc/binding/export_doc_node.cpp @@ -0,0 +1,117 @@ +#include 
"lazyllm.hpp" +#include "document_store.hpp" +#include "doc_node.hpp" + +namespace py = pybind11; + +lazyllm::DocNode init( + std::optional uid, + std::optional>> content, + std::optional group, + std::optional embedding, + std::optional> parent, + py::object store, + std::optional>> node_groups, + std::optional metadata, + std::optional global_metadata, + std::optional text +) { + lazyllm::DocNode node; + + if (content && text) { + throw std::invalid_argument("`text` and `content` cannot be set at the same time."); + } + + if (uid) node.set_uid(uid.value()); + else node.generateUUID(); + + if (content) { + if (auto p = std::get_if(&*content)) + else node.set_text(std::any_cast(content.value)); + } + else node.set_text(*text); + + if (uid.has_value() && !uid->empty()) { + node.set_uid(*uid); + } else { + node.generateUUID(); + } + + if (group.has_value()) { + node.set_group(*group); + } + + if (!content.is_none()) { + if (py::isinstance(content)) { + node.set_content(content.cast()); + } else if (py::isinstance(content) || py::isinstance(content)) { + node.set_content(content.cast>()); + } else { + throw std::invalid_argument("`content` must be a str or list[str]."); + } + } else if (text.has_value()) { + node.set_text(*text); + } else { + node.set_text(""); + } + + if (embedding.has_value()) { + node.set_embedding(*embedding); + } + + if (!metadata.is_none()) { + if (!py::isinstance(metadata)) { + throw std::invalid_argument("`metadata` must be a dict."); + } + node.set_metadata(dict_to_metadata(metadata.cast())); + } + + if (!global_metadata.is_none()) { + if (!py::isinstance(global_metadata)) { + throw std::invalid_argument("`global_metadata` must be a dict."); + } + node.set_global_metadata(dict_to_metadata(global_metadata.cast())); + } + + if (!parent.is_none()) { + if (py::isinstance(parent)) { + node.set_parent(parent.cast>()); + } else if (py::isinstance(parent)) { + throw std::invalid_argument("`parent` as str is not supported in C++ binding."); + } else { 
+ throw std::invalid_argument("`parent` must be a DocNode or None."); + } + } + + std::shared_ptr store_bridge; + if (!store.is_none()) { + py::object store_obj = store; + store_bridge = std::make_shared(store_obj); + } + node.set_store(std::move(store_bridge)); + + (void)node_groups; + + return node; +} + +void exportDocNode(py::module& m) { + py::class_(m, "DocNode") + .def(py::init(&init), + py::kw_only(), + py::arg("uid") = py::none(), + py::arg("content") = py::none(), + py::arg("group") = py::none(), + py::arg("embedding") = py::none(), + py::arg("parent") = py::none(), + py::arg("store") = py::none(), + py::arg("node_groups") = py::none(), + py::arg("metadata") = py::none(), + py::arg("global_metadata") = py::none(), + py::arg("text") = py::none() + ) + .def_property_readonly("uid", &lazyllm::DocNode::uid) + .def("set_text", &lazyllm::DocNode::set_text, py::arg("text")) + .def("get_text", &lazyllm::DocNode::get_text); +} diff --git a/csrc/binding/lazyllm.cpp b/csrc/binding/lazyllm.cpp index 2789e8e1c..b5282cdc7 100644 --- a/csrc/binding/lazyllm.cpp +++ b/csrc/binding/lazyllm.cpp @@ -1,21 +1,19 @@ #include "lazyllm.hpp" -#include "doc_node.h" +#include "document_store.hpp" +#include "doc_node.hpp" + +#include namespace py = pybind11; PYBIND11_MODULE(lazyllm_cpp, m) { m.doc() = "LazyLLM CPP Module."; - exportDoc(m); + exportAddDocStr(m); // prevent document generation py::options options; options.disable_function_signatures(); - // DocNode - py::class_(m, "DocNode") - .def(py::init<>()) - .def(py::init(), py::arg("text")) - .def("set_text", &lazyllm::DocNode::set_text, py::arg("text")) - .def("get_text", &lazyllm::DocNode::get_text); + exportDocNode(m); } diff --git a/csrc/binding/lazyllm.hpp b/csrc/binding/lazyllm.hpp index d248f6677..537514ef6 100644 --- a/csrc/binding/lazyllm.hpp +++ b/csrc/binding/lazyllm.hpp @@ -1,6 +1,11 @@ #pragma once +#include +#include + +#include #include #include -void exportDoc(pybind11::module& m); +void 
exportAddDocStr(pybind11::module& m); +void exportDocNode(pybind11::module& m); diff --git a/csrc/bridge/bridge.cpp b/csrc/bridge/bridge.cpp new file mode 100644 index 000000000..3e54ade0a --- /dev/null +++ b/csrc/bridge/bridge.cpp @@ -0,0 +1 @@ +#include "document_store.hpp" diff --git a/csrc/bridge/document_store.hpp b/csrc/bridge/document_store.hpp new file mode 100644 index 000000000..1c9e321c1 --- /dev/null +++ b/csrc/bridge/document_store.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include +#include + +namespace lazyllm { + +class DocumentStore { +public: + DocumentStore() = delete; + explicit DocumentStore(pybind11::object &store) : _py_store(store) {} + + bool is_group_active(const std::string& grp) const { + pybind11::gil_scoped_acquire gil; + pybind11::object fn = _py_store.attr("is_group_active"); + return fn(grp).cast(); + } + + pybind11::list get_nodes( + const std::string& group_names + const std::string& kb_id, + const std::vector& doc_ids + ) const { + pybind11::gil_scoped_acquire gil; + pybind11::object fn = _py_store.attr("get_nodes"); + pybind11::object result = fn( + pybind11::arg("group_name") = group_name, + pybind11::arg("kb_id") = kb_id, + pybind11::arg("doc_ids") = doc_ids); + return result.cast(); + } + +private: + pybind11::object _py_store; +}; + +} // namespace lazyllm diff --git a/csrc/cmake/tests.cmake b/csrc/cmake/tests.cmake index e495d3cbf..385d526e9 100644 --- a/csrc/cmake/tests.cmake +++ b/csrc/cmake/tests.cmake @@ -2,7 +2,7 @@ include(FetchContent) FetchContent_Declare( googletest - URL https://codeload.github.com/google/googletest/zip/refs/tags/release-1.17.0 + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip ) # Fix gtest version to maintain C++11 compatibility. 
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) diff --git a/csrc/include/doc_node.hpp b/csrc/include/doc_node.hpp index 109f245bf..e36e7dfc9 100644 --- a/csrc/include/doc_node.hpp +++ b/csrc/include/doc_node.hpp @@ -7,50 +7,87 @@ #include #include #include +#include +#include + +#include "utils.hpp" +#include "document_store.hpp" namespace lazyllm { enum class MetadataMode { ALL, EMBED, LLM, NONE }; class DocNode { -protected: - using Metadata = std::unordered_map; - using Children = std::unordered_map>; - using EmbeddingFun = std::function(const std::string&, const std::string&)>; - using EmbeddingVec = std::unordered_map>; - public: - DocNode( - uid: Optional[str] = None, - content: Optional[Union[str, List[Any]]] = None, - group: Optional[str] = None, - embedding: Optional[Dict[str,List[float]]] = None, - parent: Optional[Union[str, 'DocNode']] = None, - store=None, - node_groups: Optional[Dict[str, Dict]] = None, - metadata: Optional[Dict[str, Any]] = None, - global_metadata: Optional[Dict[str, Any]] = None, - text: Optional[str] = None - ); - explicit DocNode(const std::string& text); - DocNode(const DocNode& other); - DocNode& operator=(const DocNode& other); - virtual ~DocNode() = default; + using Metadata = std::unordered_map; + using Children = std::unordered_map>; + using EmbeddingFun = std::function(const std::string&, const std::string&)>; + using EmbeddingVec = std::unordered_map>; - const std::string& uid() const; - const std::string& group() const; - void set_group(const std::string& group); + mutable std::set embedding_state = {}; + double relevance_score = .0; + double similarity_score = .0; - bool content_is_list() const; - const std::vector& content_list() const; - const std::string& content_text() const; +private: + std::string _text = ""; + std::string _uid = ""; + std::string _group_name = ""; + std::string _parent_group_name = ""; + std::vector _content; + mutable std::string _content_hash = ""; + EmbeddingVec _embedding; + Metadata 
_metadata; + Metadata _global_metadata; + std::shared_ptr _store; + std::vector _excluded_embed_metadata_keys; + std::vector _excluded_llm_metadata_keys; + DocNode* _p_parent_node = nullptr; + Children _children; + bool _children_loaded = false; - void set_content(const std::string& text); - void set_content(const std::vector& lines); +public: + DocNode() = default; + explicit DocNode( + const std::string &text = "", + const std::vector &content = {}, + const std::string &group_name = "", + const std::string &parent_group_name = "", + const DocNode *p_parent_node = nullptr, + const std::shared_ptr p_store = nullptr, + const EmbeddingVec &embedding_vec = {}, + const Metadata &metadata = {}, + const Metadata &global_metadata = {}, + const std::string &uid = "" + ) : + _text(text), + _content(content), + _uid(uid), + _group_name(group_name), + _parent_group_name(parent_group_name), + _embedding(embedding_vec), + _metadata(metadata), + _global_metadata(global_metadata), + _store(p_store), + _parent_node(p_parent_node) {} + DocNode(const DocNode&) = default; + DocNode& operator=(const DocNode&) = default; + virtual ~DocNode() = default; - void set_text(const std::string& text); + // Getter and Setter + const std::string& uid() const { return _uid; } + void set_uid(const std::string& uid) { _uid = uid; } + const std::string& group() const { return _group; } + void set_group(const std::string& group) { _group = group; } + void set_text(const std::string& text){ + _text = text; + _content.clear(); + _content_hash = ""; + } + void set_content(const std::vector& content); const std::string& get_text() const; virtual std::string get_text_with_metadata(MetadataMode mode) const; + void set_store(std::shared_ptr store); + const std::shared_ptr& store() const; std::string content_hash() const; @@ -63,9 +100,9 @@ class DocNode { void set_embedding_value(const std::string& key, const std::vector& value); void check_embedding_state(const std::string& embed_key) const; - DocNode* 
parent(); - const DocNode* parent() const; - void set_parent(DocNode* parent); + std::shared_ptr parent(); + const std::shared_ptr parent() const; + void set_parent(std::shared_ptr parent); Children& children(); const Children& children() const; @@ -110,29 +147,9 @@ class DocNode { double relevance_score() const; double similarity_score() const; -protected: - void invalidate_content_hash(); + void generateUUID() { _uid = GenerateUUID(); } + - std::string _uid; - std::string _group; - std::string _text; - bool _content_is_list; - std::vector _content_list; - EmbeddingVec _embedding; - Metadata _metadata; - Metadata _global_metadata; - std::vector _excluded_embed_metadata_keys; - std::vector _excluded_llm_metadata_keys; - DocNode* _parent; - Children _children; - bool _children_loaded; - mutable std::set _embedding_state; - mutable std::string _content_hash; - mutable bool _content_hash_dirty; - double _relevance_score; - bool _has_relevance_score; - double _similarity_score; - bool _has_similarity_score; }; class QADocNode : public DocNode { diff --git a/csrc/scripts/config_cmake.sh b/csrc/scripts/config_cmake.sh new file mode 100644 index 000000000..a02b81209 --- /dev/null +++ b/csrc/scripts/config_cmake.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +set -euo pipefail + +cmake -S csrc -B build \ + -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ + -DCMAKE_BUILD_TYPE=Debug diff --git a/csrc/src/doc_node.cpp b/csrc/src/doc_node.cpp index abcf20252..65b84c7fd 100644 --- a/csrc/src/doc_node.cpp +++ b/csrc/src/doc_node.cpp @@ -1,6 +1,3 @@ -#include "doc_node.hpp" -#include "utils.hpp" - #include #include #include @@ -10,116 +7,23 @@ #include #include -namespace lazyllm { - -DocNode::DocNode() - : _uid(GenerateUUID()), - _group(), - _text(), - _content_is_list(false), - _parent(nullptr), - _children_loaded(false), - _content_hash(), - _content_hash_dirty(true), - _relevance_score(0.0), - _has_relevance_score(false), - _similarity_score(0.0), - _has_similarity_score(false) 
{} - -DocNode::DocNode(const std::string& text) : DocNode() { - set_text(text); -} - -DocNode::DocNode(const DocNode& other) - : _uid(other._uid), - _group(other._group), - _text(other._text), - _content_is_list(other._content_is_list), - _content_list(other._content_list), - _embedding(other._embedding), - _metadata(other._metadata), - _global_metadata(other._global_metadata), - _excluded_embed_metadata_keys(other._excluded_embed_metadata_keys), - _excluded_llm_metadata_keys(other._excluded_llm_metadata_keys), - _parent(other._parent), - _children(other._children), - _children_loaded(other._children_loaded), - _embedding_state(other._embedding_state), - _content_hash(other._content_hash), - _content_hash_dirty(other._content_hash_dirty), - _relevance_score(other._relevance_score), - _has_relevance_score(other._has_relevance_score), - _similarity_score(other._similarity_score), - _has_similarity_score(other._has_similarity_score) {} - -DocNode& DocNode::operator=(const DocNode& other) { - if (this == &other) { - return *this; - } - _uid = other._uid; - _group = other._group; - _text = other._text; - _content_is_list = other._content_is_list; - _content_list = other._content_list; - _embedding = other._embedding; - _metadata = other._metadata; - _global_metadata = other._global_metadata; - _excluded_embed_metadata_keys = other._excluded_embed_metadata_keys; - _excluded_llm_metadata_keys = other._excluded_llm_metadata_keys; - _parent = other._parent; - _children = other._children; - _children_loaded = other._children_loaded; - _embedding_state = other._embedding_state; - _content_hash = other._content_hash; - _content_hash_dirty = other._content_hash_dirty; - _relevance_score = other._relevance_score; - _has_relevance_score = other._has_relevance_score; - _similarity_score = other._similarity_score; - _has_similarity_score = other._has_similarity_score; - return *this; -} - -const std::string& DocNode::uid() const { - return _uid; -} - -const std::string& 
DocNode::group() const { - return _group; -} - -void DocNode::set_group(const std::string& group) { - _group = group; -} - -bool DocNode::content_is_list() const { - return _content_is_list; -} +#include "doc_node.hpp" +#include "utils.hpp" -const std::vector& DocNode::content_list() const { - return _content_list; -} +namespace lazyllm { const std::string& DocNode::content_text() const { return _text; } -void DocNode::set_content(const std::string& text) { - set_text(text); -} - void DocNode::set_content(const std::vector& lines) { _content_is_list = true; - _content_list = lines; + _content = lines; _text = JoinLines(lines); - invalidate_content_hash(); + content_hash(); } -void DocNode::set_text(const std::string& text) { - _text = text; - _content_is_list = false; - _content_list.clear(); - invalidate_content_hash(); -} + const std::string& DocNode::get_text() const { return _text; @@ -136,11 +40,16 @@ std::string DocNode::get_text_with_metadata(MetadataMode mode) const { return metadata_str + "\n\n" + _text; } +void DocNode::set_store(std::shared_ptr store) { + _store = std::move(store); +} + +const std::shared_ptr& DocNode::store() const { + return _store; +} + std::string DocNode::content_hash() const { - if (_content_hash_dirty) { - _content_hash = Sha256Hex(_text); - _content_hash_dirty = false; - } + _content_hash = Sha256Hex(_text); return _content_hash; } @@ -191,7 +100,6 @@ void DocNode::set_embedding_value(const std::string& key, const std::vector lock(_embedding_mutex); if (_embedding.find(embed_key) != _embedding.end()) { _embedding_state.erase(embed_key); break; @@ -404,25 +312,15 @@ std::string DocNode::get_content(MetadataMode mode) const { DocNode DocNode::with_score(double score) const { DocNode node(*this); node._relevance_score = score; - node._has_relevance_score = true; return node; } DocNode DocNode::with_sim_score(double score) const { DocNode node(*this); node._similarity_score = score; - node._has_similarity_score = true; return node; } 
-bool DocNode::has_relevance_score() const { - return _has_relevance_score; -} - -bool DocNode::has_similarity_score() const { - return _has_similarity_score; -} - double DocNode::relevance_score() const { return _relevance_score; } @@ -431,10 +329,6 @@ double DocNode::similarity_score() const { return _similarity_score; } -void DocNode::invalidate_content_hash() { - _content_hash_dirty = true; -} - QADocNode::QADocNode(const std::string& query, const std::string& answer) : DocNode(query), _answer(Trim(answer)) {} diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 8b50fdd8f..b50dc2f3d 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -50,7 +50,7 @@ def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[ self._store = store self._node_groups: Dict[str, Dict] = node_groups or {} self._lock = threading.Lock() - self._embedding_state = set() + self.embedding_state = set() self.relevance_score = None self.similarity_score = None self._content_hash: Optional[str] = None @@ -247,7 +247,7 @@ def check_embedding_state(self, embed_key: str) -> None: while True: with self._lock: if not self.has_missing_embedding(embed_key): - self._embedding_state.discard(embed_key) + self.embedding_state.discard(embed_key) break time.sleep(1) diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index c6b0eebe9..a02cfccd0 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -871,8 +871,7 @@ def parallel_do_embedding(embed: Dict[str, Callable], embed_keys: Optional[Union modified_nodes.append(node) for k in miss: tasks_by_key[k].append(node) - if hasattr(node, '_embedding_state'): - node._embedding_state.add(k) + node.embedding_state.add(k) if not tasks_by_key: return [] @@ -907,9 +906,9 @@ def _process_key(k: str, knodes: List[DocNode]): except Exception as e: lazyllm.LOG.error(f'[LazyLLM - parallel_do_embedding][{k}] error: {e}') for n in knodes: - if hasattr(n, 
'_embedding_state') and k in n._embedding_state: + if k in n.embedding_state: with n._lock: - n._embedding_state.remove(k) + n.embedding_state.remove(k) raise e with ThreadPoolExecutor(max_workers=min(max_workers, len(tasks_by_key))) as ex: From 7ea108edcc2fe41f7fe9378f78e8366aee0b32cc Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 28 Jan 2026 11:56:29 +0800 Subject: [PATCH 07/40] undo workflow fix --- .github/workflows/publish_release.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/publish_release.yml b/.github/workflows/publish_release.yml index 699eabe42..ed5840fd7 100644 --- a/.github/workflows/publish_release.yml +++ b/.github/workflows/publish_release.yml @@ -187,12 +187,6 @@ jobs: with: name: repo-with-docs path: ./repo_artifact - - - name: Install Python dev headers (Ubuntu only) - if: startsWith(matrix.os, 'ubuntu') - run: | - sudo apt-get update - sudo apt-get install -y python3-dev - name: Extract repo-with-docs run: | From f0f6657b5c3e1e0d4e4f3d8bb65d57c8eb846beb Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 30 Jan 2026 15:09:40 +0800 Subject: [PATCH 08/40] refactor --- .gitignore | 1 + csrc/CMakeLists.txt | 34 ++- csrc/{include => }/README.md | 0 .../{bridge/bridge.cpp => adaptor/adptor.cpp} | 0 csrc/adaptor/document_store.hpp | 105 ++++++++ csrc/binding/export_doc_node.cpp | 80 ++++--- csrc/bridge/document_store.hpp | 40 ---- csrc/cmake/tests.cmake | 2 - csrc/core/include/doc_node.hpp | 163 +++++++++++++ csrc/core/include/utils.hpp | 132 +++++++++++ csrc/{ => core}/src/doc_node.cpp | 4 +- csrc/{ => core}/src/utils.cpp | 0 csrc/include/doc_node.hpp | 183 -------------- csrc/include/utils.hpp | 224 ------------------ csrc/src/README.md | 0 15 files changed, 475 insertions(+), 493 deletions(-) rename csrc/{include => }/README.md (100%) rename csrc/{bridge/bridge.cpp => adaptor/adptor.cpp} (100%) create mode 100644 csrc/adaptor/document_store.hpp delete mode 100644 csrc/bridge/document_store.hpp create mode 100644 
csrc/core/include/doc_node.hpp create mode 100644 csrc/core/include/utils.hpp rename csrc/{ => core}/src/doc_node.cpp (98%) rename csrc/{ => core}/src/utils.cpp (100%) delete mode 100644 csrc/include/doc_node.hpp delete mode 100644 csrc/include/utils.hpp delete mode 100644 csrc/src/README.md diff --git a/.gitignore b/.gitignore index 0e34cdd67..c0e9cec43 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ test/ dist/ tmp/ build +.cache/ *.lock *.db mkdocs.yml diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 072771e18..3ee930048 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -4,30 +4,46 @@ project(LazyLLMCPP LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +add_compile_options( + -Werror + -Wshadow +) +include(FetchContent) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) find_package(pybind11 CONFIG REQUIRED) -# Config lazyllm_bridge lib which defines tmp classes to bridge python and cpp. -file(GLOB_RECURSE LAZYLLM_BRIDGE_SOURCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_SOURCE_DIR}/bridge/*.cpp") -add_library(lazyllm_bridge STATIC ${LAZYLLM_BRIDGE_SOURCES}) -target_include_directories(lazyllm_bridge PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/bridge) -target_link_libraries(lazyllm_bridge PUBLIC pybind11::headers Python3::Python) +find_package(xxHash QUIET) +if (NOT TARGET xxhash) + FetchContent_Declare( + xxhash + GIT_REPOSITORY https://github.com/Cyan4973/xxHash.git + GIT_TAG v0.8.2 + ) + FetchContent_Populate(xxhash) + add_subdirectory(${xxhash_SOURCE_DIR}/cmake_unofficial ${xxhash_BINARY_DIR}) +endif() # Config lazyllm_core lib with pure cpp code. 
file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + "${CMAKE_CURRENT_SOURCE_DIR}/core/src/*.cpp") add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) -target_link_libraries(lazyllm_core PUBLIC lazyllm_bridge) +target_link_libraries(lazyllm_core PUBLIC xxhash) + +# Config lazyllm_adaptor lib which maintains callback invocations. +file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/adaptor/*.cpp") +add_library(lazyllm_adaptor STATIC ${LAZYLLM_ADAPTOR_SOURCES}) +target_include_directories(lazyllm_adaptor PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/adaptor) +target_link_libraries(lazyllm_adaptor PUBLIC pybind11::headers Python3::Python lazyllm_core) # Config lazyllm_cpp lib with binding infomations. file(GLOB_RECURSE LAZYLLM_BINDING_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/binding/*.cpp") set(INTERFACE_TARGET_NAME lazyllm_cpp) pybind11_add_module(${INTERFACE_TARGET_NAME} ${LAZYLLM_BINDING_SOURCES}) -target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core) +target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core lazyllm_adaptor) if (CMAKE_BUILD_TYPE STREQUAL "Debug") # SHOW_SYMBOL diff --git a/csrc/include/README.md b/csrc/README.md similarity index 100% rename from csrc/include/README.md rename to csrc/README.md diff --git a/csrc/bridge/bridge.cpp b/csrc/adaptor/adptor.cpp similarity index 100% rename from csrc/bridge/bridge.cpp rename to csrc/adaptor/adptor.cpp diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp new file mode 100644 index 000000000..c12d1800d --- /dev/null +++ b/csrc/adaptor/document_store.hpp @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace lazyllm { + +struct NodeGroup { + enum class Type { + ORIGINAL, CHUNK, SUMMARY, IMAGE_INFO, QUESTION_ANSWER, 
OTHER + }; + std::string parent; + std::string display_name; + Type type; + NodeGroup( + const std::string& parent, + const std::string& display_name, + const Type& type = Type::ORIGINAL) : + parent(parent), display_name(display_name), type(type) {} +}; + +class DocumentStore { +public: + // RAG system metadata keys + static constexpr std::string_view RAG_KB_ID_KEY = "kb_id"; + static constexpr std::string_view RAG_DOC_ID_KEY = "docid"; + + DocumentStore() = delete; + explicit DocumentStore( + const pybind11::object &store, + const std::unordered_map &map) : + _py_store(store), _node_groups_map(map) {} + + // Cache-aware factory to avoid rebuilding wrappers for the same Python store. + static std::shared_ptr from_store( + const pybind11::object &store, const std::unordered_map &map) { + if (store.is_none()) return nullptr; + + pybind11::gil_scoped_acquire gil; + PyObject *key = store.ptr(); + auto &cache = store_cache(); + auto it = cache.find(key); + if (it != cache.end()) { + if (auto existing = it->second.lock()) + return existing; + } + auto created = std::shared_ptr(new DocumentStore(store, map)); + cache[key] = created; + return created; + } + + bool is_group_active(const std::string& grp) const { + pybind11::gil_scoped_acquire gil; + pybind11::object fn = _py_store.attr("is_group_active"); + return fn(grp).cast(); + } + + pybind11::object get_node( + const std::string& group_name, + const std::string& uid, + const std::string& kb_id + ) const { + pybind11::gil_scoped_acquire gil; + pybind11::object fn = _py_store.attr("get_nodes"); + pybind11::object result = fn( + pybind11::arg("group_name") = group_name, + pybind11::arg("uids") = std::vector({uid}), + pybind11::arg("kb_id") = kb_id, + pybind11::arg("display") = true); + return result.cast()[0]; + } + + std::vector get_nodes( + const std::string& group_name, + const std::string& kb_id, + const std::string& doc_id + ) const { + pybind11::gil_scoped_acquire gil; + pybind11::object fn = 
_py_store.attr("get_nodes"); + pybind11::object result = fn( + pybind11::arg("group_name") = group_name, + pybind11::arg("kb_id") = kb_id, + pybind11::arg("doc_ids") = std::vector({doc_id})); + return result.cast>(); + } + +private: + // Keep the underlying Python store object alive for callback invocations. + pybind11::object _py_store; + std::unordered_map _node_groups_map; + + // Cache by Python object identity to ensure one wrapper per store instance. + static std::unordered_map> &store_cache() { + static std::unordered_map> cache; + return cache; + } +}; + +} // namespace lazyllm diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index c74369cf7..aaf5051c3 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -1,12 +1,16 @@ +#include +#include + #include "lazyllm.hpp" #include "document_store.hpp" #include "doc_node.hpp" +#include "utils.hpp" namespace py = pybind11; lazyllm::DocNode init( std::optional uid, - std::optional>> content, + std::optional>> content, std::optional group, std::optional embedding, std::optional> parent, @@ -17,44 +21,53 @@ lazyllm::DocNode init( std::optional global_metadata, std::optional text ) { - lazyllm::DocNode node; - if (content && text) { throw std::invalid_argument("`text` and `content` cannot be set at the same time."); } - if (uid) node.set_uid(uid.value()); - else node.generateUUID(); + // Find parent node. + lazyllm::DocNode* p_parent_node = nullptr; + std::shared_ptr store_bridge = nullptr; + + // Build node groups map. + // Usually, parent + store + node_groups are not None at the same time. 
+ if (parent && !store.is_none() && node_groups && global_metadata && group) { + std::unordered_map node_groups_map; + node_groups_map.reserve(node_groups->size()); + for (const auto &entry : *node_groups) { + auto &group_key = entry.first; + auto &group_dict = entry.second; + node_groups_map.emplace(group_key, lazyllm::NodeGroup( + std::any_cast(group_dict.at(std::string("parent"))), + std::any_cast(group_dict.at(std::string("display_name"))) + )); + } + auto store_bridge = lazyllm::DocumentStore::from_store(store, node_groups_map); - if (content) { - if (auto p = std::get_if(&*content)) - else node.set_text(std::any_cast(content.value)); - } - else node.set_text(*text); + auto kb_id = std::any_cast((*global_metadata)[ + std::string(lazyllm::DocumentStore::RAG_KB_ID_KEY)]); + auto doc_id = std::any_cast((*global_metadata)[ + std::string(lazyllm::DocumentStore::RAG_DOC_ID_KEY)]); - if (uid.has_value() && !uid->empty()) { - node.set_uid(*uid); - } else { - node.generateUUID(); + if (std::holds_alternative(*parent)) { + const std::string& parent_uid = std::get(*parent); + p_parent_node = store_bridge->get_node(parent_uid, *group, kb_id).cast(); + } + else + p_parent_node = std::get(*parent).cast(); } - if (group.has_value()) { - node.set_group(*group); - } + lazyllm::DocNode node( + content.value_or(text.value_or(std::string(""))), + uid.value_or(lazyllm::GenerateUUID()), + group.value_or(""), + p_parent_node, + store_bridge, + embedding.value_or(lazyllm::DocNode::EmbeddingVec()), + metadata.value_or(lazyllm::DocNode::Metadata()), + global_metadata.value_or(lazyllm::DocNode::Metadata()) + ); - if (!content.is_none()) { - if (py::isinstance(content)) { - node.set_content(content.cast()); - } else if (py::isinstance(content) || py::isinstance(content)) { - node.set_content(content.cast>()); - } else { - throw std::invalid_argument("`content` must be a str or list[str]."); - } - } else if (text.has_value()) { - node.set_text(*text); - } else { - node.set_text(""); - } 
if (embedding.has_value()) {
         node.set_embedding(*embedding);
@@ -87,7 +100,7 @@ lazyllm::DocNode init(
     std::shared_ptr store_bridge;
     if (!store.is_none()) {
         py::object store_obj = store;
-        store_bridge = std::make_shared(store_obj);
+        store_bridge = lazyllm::DocumentStore::from_store(store_obj, node_groups_map);
     }
     node.set_store(std::move(store_bridge));
@@ -109,9 +122,11 @@ lazyllm::DocNode init(
         py::arg("node_groups") = py::none(),
         py::arg("metadata") = py::none(),
         py::arg("global_metadata") = py::none(),
-        py::arg("text") = py::none()
+        py::arg("text") = py::none(),
+
+        py::keep_alive<1, 7>() // Keep store alive
     )
     .def_property_readonly("uid", &lazyllm::DocNode::uid)
     .def("set_text", &lazyllm::DocNode::set_text, py::arg("text"))
     .def("get_text", &lazyllm::DocNode::get_text);
 }
diff --git a/csrc/bridge/document_store.hpp b/csrc/bridge/document_store.hpp
deleted file mode 100644
index 1c9e321c1..000000000
--- a/csrc/bridge/document_store.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-#pragma once
-
-#include
-#include
-
-#include
-#include
-
-namespace lazyllm {
-
-class DocumentStore {
-public:
-    DocumentStore() = delete;
-    explicit DocumentStore(pybind11::object &store) : _py_store(store) {}
-
-    bool is_group_active(const std::string& grp) const {
-        pybind11::gil_scoped_acquire gil;
-        pybind11::object fn = _py_store.attr("is_group_active");
-        return fn(grp).cast();
-    }
-
-    pybind11::list get_nodes(
-        const std::string& group_names
-        const std::string& kb_id,
-        const std::vector& doc_ids
-    ) const {
-        pybind11::gil_scoped_acquire gil;
-        pybind11::object fn = _py_store.attr("get_nodes");
-        pybind11::object result = fn(
-            pybind11::arg("group_name") = group_name,
-            pybind11::arg("kb_id") = kb_id,
-            pybind11::arg("doc_ids") = doc_ids);
-        return result.cast();
-    }
-
-private:
-    pybind11::object _py_store;
-};
-
-} // namespace lazyllm
diff --git a/csrc/cmake/tests.cmake b/csrc/cmake/tests.cmake
index 385d526e9..88803297a 100644
--- a/csrc/cmake/tests.cmake
+++ b/csrc/cmake/tests.cmake @@ -1,5 +1,3 @@ -include(FetchContent) - FetchContent_Declare( googletest URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp new file mode 100644 index 000000000..11597ca92 --- /dev/null +++ b/csrc/core/include/doc_node.hpp @@ -0,0 +1,163 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +namespace lazyllm { + +enum class MetadataMode { ALL, EMBED, LLM, NONE }; + +class DocNode { +public: + using Metadata = std::unordered_map; + using Children = std::unordered_map>; + using EmbeddingFun = std::function(const std::string&, const std::string&)>; + using EmbeddingVec = std::unordered_map>; + + DocNode* _p_parent_node = nullptr; +private: + std::shared_ptr _p_root_text = nullptr; + std::string_view _text_view; + std::string _group_name; + std::string _uid; + mutable size_t _text_hash = 0; + + Metadata _metadata; + std::shared_ptr _p_global_metadata; + std::vector _excluded_embed_metadata_keys; + std::vector _excluded_llm_metadata_keys; + + EmbeddingVec _embedding_vec; + mutable std::set _embedding_state = {}; + double _relevance_score = .0; + double _similarity_score = .0; + + Children _children; + +public: + DocNode() = delete; + explicit DocNode( + const std::string_view& text_view, + const std::string& group_name, + const std::string& uid = "", + const EmbeddingVec& embedding_vec = {}, + const Metadata& metadata = {}, + const std::shared_ptr& global_metadata = {}, + const std::shared_ptr& p_raw_text = nullptr + ) : + _text_view(text_view), + _group_name(group_name), + _uid(uid), + _metadata(metadata), + _p_global_metadata(global_metadata), + _embedding_vec(embedding_vec), + _p_root_text(p_raw_text) + { + if (uid.empty()) _uid = GenerateUUID(); + } + + DocNode(const DocNode&) = default; + DocNode& operator=(const 
DocNode&) = default; + virtual ~DocNode() = default; + + size_t evaluate_text_hash() const { + return static_cast(XXH64(_text_view.data(), _text_view.size(), 0)); + } + std::vector find_children(const std::vector& nodes) const { + std::vector children; + for (auto p_node : nodes) + if (p_node->get_parent_node() == _p_parent_node) + children.push_back(p_node); + return children; + } + + // Getter and Setter + const std::string& get_uid() const { return _uid; } + const std::string& get_group_name() const { return _group_name; } + const std::string_view& get_text_view() const { return _text_view; } + const std::string& get_text() const { return std::string(_text_view); } + void set_text(const std::string& text) { + _p_root_text = std::make_shared(text); + _text_view = *_p_root_text; + _text_hash = evaluate_text_hash(); + } + const size_t& get_text_hash() const { + if (_text_hash == 0) _text_hash = evaluate_text_hash(); + return _text_hash; + } + const DocNode* get_parent_node() { return _p_parent_node; } + void set_parent_node(DocNode* p_parent_node) { _p_parent_node = p_parent_node; } + + const std::vector& get_children() const { + return _children; + } + + const Metadata& get_metadata() const { return _metadata; } + void set_metadata(const Metadata& metadata) { _metadata = metadata; } + + std::shared_ptr global_metadata_ptr() { return _p_global_metadata; } + std::shared_ptr global_metadata_ptr() const { return _p_global_metadata; } + Metadata& global_metadata() { return *_p_global_metadata; } + const Metadata& global_metadata() const { return *_p_global_metadata; } + void set_global_metadata_ptr(std::shared_ptr p_global_metadata) { + _p_global_metadata = std::move(p_global_metadata); + } + void set_global_metadata(const Metadata& global_metadata) { + _p_global_metadata = std::make_shared(global_metadata); + } + + const std::vector& excluded_embed_metadata_keys() const { + return _excluded_embed_metadata_keys; + } + void set_excluded_embed_metadata_keys(const 
std::vector& keys) { + _excluded_embed_metadata_keys = keys; + } + + const std::vector& excluded_llm_metadata_keys() const { + return _excluded_llm_metadata_keys; + } + void set_excluded_llm_metadata_keys(const std::vector& keys) { + _excluded_llm_metadata_keys = keys; + } + + EmbeddingVec& embedding_vec() { return _embedding_vec; } + const EmbeddingVec& embedding_vec() const { return _embedding_vec; } + void set_embedding_vec(const EmbeddingVec& embedding_vec) { _embedding_vec = embedding_vec; } + + std::set& embedding_state() { return _embedding_state; } + const std::set& embedding_state() const { return _embedding_state; } + void set_embedding_state(const std::set& embedding_state) { + _embedding_state = embedding_state; + } + + double relevance_score() const { return _relevance_score; } + void set_relevance_score(double relevance_score) { _relevance_score = relevance_score; } + + double similarity_score() const { return _similarity_score; } + void set_similarity_score(double similarity_score) { _similarity_score = similarity_score; } + + const std::variant& parent_node() const { return _p_parent_node; } + void set_parent_node(const std::variant& parent_node) { + _p_parent_node = parent_node; + } + + Children& children() { return _children; } + const Children& children() const { return _children; } + void set_children(const Children& children) { _children = children; } + + bool children_loaded() const { return _children_loaded; } + void set_children_loaded(bool children_loaded) { _children_loaded = children_loaded; } +}; + +} // namespace lazyllm diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp new file mode 100644 index 000000000..76d50c6d1 --- /dev/null +++ b/csrc/core/include/utils.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace lazyllm { + +std::string JoinLines(const std::vector& lines) { + std::ostringstream oss; + for (size_t i = 0; i < lines.size(); ++i) { + if (i > 0) { 
+ oss << "\n"; + } + oss << lines[i]; + } + return oss.str(); +} + +std::string Trim(const std::string& value) { + size_t start = 0; + while (start < value.size() && std::isspace(static_cast(value[start]))) { + ++start; + } + size_t end = value.size(); + while (end > start && std::isspace(static_cast(value[end - 1]))) { + --end; + } + return value.substr(start, end - start); +} + +std::string GenerateUUID() { + static const char HEX_CHAR[] = "0123456789abcdef"; + static const int SEGS[] = {8, 4, 4, 4, 12}; + + // Single static generator per thread. + static thread_local std::mt19937 GEN(std::random_device{}()); + static thread_local std::uniform_int_distribution DIST(0, 15); + + std::string out; + out.reserve(36); + for (int segLength : SEGS) { + for (int i = 0; i < segLength; ++i) + out.push_back(HEX_CHAR[DIST(GEN)]); + if (segLength < 12) + out.push_back('-'); + } + return out; +} + +std::string ToLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return value; +} + +std::string GetExtension(const std::string& path) { + const size_t dot = path.find_last_of('.'); + if (dot == std::string::npos || dot + 1 >= path.size()) { + return ""; + } + return ToLower(path.substr(dot + 1)); +} + +std::string ImageMimeType(const std::string& path) { + const std::string ext = GetExtension(path); + if (ext == "jpg" || ext == "jpeg" || ext == "jfif" || ext == "jpe") { + return "image/jpeg"; + } + if (ext == "png" || ext == "apng") { + return "image/png"; + } + if (ext == "gif") { + return "image/gif"; + } + if (ext == "bmp" || ext == "dib") { + return "image/bmp"; + } + if (ext == "tif" || ext == "tiff") { + return "image/tiff"; + } + if (ext == "webp") { + return "image/webp"; + } + if (ext == "ico") { + return "image/x-icon"; + } + if (ext == "icns") { + return "image/icns"; + } + return ""; +} + +bool ReadFileBinary(const std::string& path, std::string* out) { + std::ifstream 
file(path.c_str(), std::ios::binary); + if (!file) { + return false; + } + std::ostringstream buffer; + buffer << file.rdbuf(); + *out = buffer.str(); + return true; +} + +std::string Base64Encode(const std::string& data) { + static const char kBase64Chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string out; + out.reserve(((data.size() + 2) / 3) * 4); + int val = 0; + int valb = -6; + for (unsigned char c : data) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(kBase64Chars[(val >> valb) & 0x3F]); + valb -= 6; + } + } + if (valb > -6) { + out.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]); + } + while (out.size() % 4) { + out.push_back('='); + } + return out; +} + +} // namespace lazyllm diff --git a/csrc/src/doc_node.cpp b/csrc/core/src/doc_node.cpp similarity index 98% rename from csrc/src/doc_node.cpp rename to csrc/core/src/doc_node.cpp index 65b84c7fd..ca2d592f6 100644 --- a/csrc/src/doc_node.cpp +++ b/csrc/core/src/doc_node.cpp @@ -26,7 +26,7 @@ void DocNode::set_content(const std::vector& lines) { const std::string& DocNode::get_text() const { - return _text; + return _text_view; } std::string DocNode::get_text_with_metadata(MetadataMode mode) const { @@ -86,14 +86,12 @@ void DocNode::do_embedding(const std::unordered_map& for (const auto& item : embed) { generated[item.first] = item.second(input, ""); } - std::lock_guard lock(_embedding_mutex); for (const auto& item : generated) { _embedding[item.first] = item.second; } } void DocNode::set_embedding_value(const std::string& key, const std::vector& value) { - std::lock_guard lock(_embedding_mutex); _embedding[key] = value; } diff --git a/csrc/src/utils.cpp b/csrc/core/src/utils.cpp similarity index 100% rename from csrc/src/utils.cpp rename to csrc/core/src/utils.cpp diff --git a/csrc/include/doc_node.hpp b/csrc/include/doc_node.hpp deleted file mode 100644 index e36e7dfc9..000000000 --- a/csrc/include/doc_node.hpp +++ /dev/null 
@@ -1,183 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "utils.hpp" -#include "document_store.hpp" - -namespace lazyllm { - -enum class MetadataMode { ALL, EMBED, LLM, NONE }; - -class DocNode { -public: - using Metadata = std::unordered_map; - using Children = std::unordered_map>; - using EmbeddingFun = std::function(const std::string&, const std::string&)>; - using EmbeddingVec = std::unordered_map>; - - mutable std::set embedding_state = {}; - double relevance_score = .0; - double similarity_score = .0; - -private: - std::string _text = ""; - std::string _uid = ""; - std::string _group_name = ""; - std::string _parent_group_name = ""; - std::vector _content; - mutable std::string _content_hash = ""; - EmbeddingVec _embedding; - Metadata _metadata; - Metadata _global_metadata; - std::shared_ptr _store; - std::vector _excluded_embed_metadata_keys; - std::vector _excluded_llm_metadata_keys; - DocNode* _p_parent_node = nullptr; - Children _children; - bool _children_loaded = false; - -public: - DocNode() = default; - explicit DocNode( - const std::string &text = "", - const std::vector &content = {}, - const std::string &group_name = "", - const std::string &parent_group_name = "", - const DocNode *p_parent_node = nullptr, - const std::shared_ptr p_store = nullptr, - const EmbeddingVec &embedding_vec = {}, - const Metadata &metadata = {}, - const Metadata &global_metadata = {}, - const std::string &uid = "" - ) : - _text(text), - _content(content), - _uid(uid), - _group_name(group_name), - _parent_group_name(parent_group_name), - _embedding(embedding_vec), - _metadata(metadata), - _global_metadata(global_metadata), - _store(p_store), - _parent_node(p_parent_node) {} - DocNode(const DocNode&) = default; - DocNode& operator=(const DocNode&) = default; - virtual ~DocNode() = default; - - // Getter and Setter - const std::string& uid() const { return _uid; } - void set_uid(const std::string& 
uid) { _uid = uid; } - const std::string& group() const { return _group; } - void set_group(const std::string& group) { _group = group; } - void set_text(const std::string& text){ - _text = text; - _content.clear(); - _content_hash = ""; - } - void set_content(const std::vector& content); - const std::string& get_text() const; - virtual std::string get_text_with_metadata(MetadataMode mode) const; - void set_store(std::shared_ptr store); - const std::shared_ptr& store() const; - - std::string content_hash() const; - - EmbeddingVec& embedding(); - const EmbeddingVec& embedding() const; - void set_embedding(const EmbeddingVec& embed); - - std::vector has_missing_embedding(const std::vector& embed_keys) const; - virtual void do_embedding(const std::unordered_map& embed); - void set_embedding_value(const std::string& key, const std::vector& value); - void check_embedding_state(const std::string& embed_key) const; - - std::shared_ptr parent(); - const std::shared_ptr parent() const; - void set_parent(std::shared_ptr parent); - - Children& children(); - const Children& children() const; - void set_children(const Children& children); - - DocNode* root_node(); - const DocNode* root_node() const; - bool is_root_node() const; - - Metadata& metadata(); - const Metadata& metadata() const; - void set_metadata(const Metadata& metadata); - - Metadata& global_metadata(); - const Metadata& global_metadata() const; - void set_global_metadata(const Metadata& global_metadata); - - std::vector excluded_embed_metadata_keys() const; - void set_excluded_embed_metadata_keys(const std::vector& keys); - std::vector excluded_llm_metadata_keys() const; - void set_excluded_llm_metadata_keys(const std::vector& keys); - - std::string docpath() const; - void set_docpath(const std::string& path); - - std::string get_children_str() const; - std::string get_parent_id() const; - - std::string to_string() const; - bool operator==(const DocNode& other) const; - bool operator!=(const DocNode& other) 
const; - std::size_t hash() const; - - std::string get_metadata_str(MetadataMode mode = MetadataMode::ALL) const; - virtual std::string get_content(MetadataMode mode = MetadataMode::LLM) const; - - DocNode with_score(double score) const; - DocNode with_sim_score(double score) const; - - bool has_relevance_score() const; - bool has_similarity_score() const; - double relevance_score() const; - double similarity_score() const; - - void generateUUID() { _uid = GenerateUUID(); } - - -}; - -class QADocNode : public DocNode { -public: - QADocNode(const std::string& query, const std::string& answer); - QADocNode(const std::string& query, const std::string& answer, const std::string& uid, - const std::string& group = std::string()); - const std::string& answer() const; - std::string get_text_with_metadata(MetadataMode mode) const override; - -private: - std::string _answer; -}; - -class ImageDocNode : public DocNode { -public: - ImageDocNode(const std::string& image_path); - ImageDocNode(const std::string& image_path, const std::string& uid, - const std::string& group = std::string()); - - const std::string& image_path() const; - std::string get_content(MetadataMode mode = MetadataMode::LLM) const override; - void do_embedding(const std::unordered_map& embed) override; - std::string get_text_with_metadata(MetadataMode mode) const override; - -private: - std::string _image_path; - std::string _modality; -}; - -} // namespace lazyllm diff --git a/csrc/include/utils.hpp b/csrc/include/utils.hpp deleted file mode 100644 index 25a9d7681..000000000 --- a/csrc/include/utils.hpp +++ /dev/null @@ -1,224 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace lazyllm { - -std::string JoinLines(const std::vector& lines) { - std::ostringstream oss; - for (size_t i = 0; i < lines.size(); ++i) { - if (i > 0) { - oss << "\n"; - } - oss << lines[i]; - } - return oss.str(); -} - -std::string Trim(const std::string& value) { - size_t start = 0; - 
while (start < value.size() && std::isspace(static_cast(value[start]))) { - ++start; - } - size_t end = value.size(); - while (end > start && std::isspace(static_cast(value[end - 1]))) { - --end; - } - return value.substr(start, end - start); -} - -std::string GenerateUUID() { - static const char HEX_CHAR[] = "0123456789abcdef"; - static const int SEGS[] = {8, 4, 4, 4, 12}; - - // Single static generator per thread. - static thread_local std::mt19937 GEN(std::random_device{}()); - static thread_local std::uniform_int_distribution DIST(0, 15); - - std::string out; - out.reserve(36); - for (int segLength : SEGS) { - for (int i = 0; i < segLength; ++i) - out.push_back(HEX_CHAR[DIST(GEN)]); - if (segLength < 12) - out.push_back('-'); - } - return out; -} - -std::string ToLower(std::string value) { - std::transform(value.begin(), value.end(), value.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - return value; -} - -std::string GetExtension(const std::string& path) { - const size_t dot = path.find_last_of('.'); - if (dot == std::string::npos || dot + 1 >= path.size()) { - return ""; - } - return ToLower(path.substr(dot + 1)); -} - -std::string ImageMimeType(const std::string& path) { - const std::string ext = GetExtension(path); - if (ext == "jpg" || ext == "jpeg" || ext == "jfif" || ext == "jpe") { - return "image/jpeg"; - } - if (ext == "png" || ext == "apng") { - return "image/png"; - } - if (ext == "gif") { - return "image/gif"; - } - if (ext == "bmp" || ext == "dib") { - return "image/bmp"; - } - if (ext == "tif" || ext == "tiff") { - return "image/tiff"; - } - if (ext == "webp") { - return "image/webp"; - } - if (ext == "ico") { - return "image/x-icon"; - } - if (ext == "icns") { - return "image/icns"; - } - return ""; -} - -bool ReadFileBinary(const std::string& path, std::string* out) { - std::ifstream file(path.c_str(), std::ios::binary); - if (!file) { - return false; - } - std::ostringstream buffer; - buffer << file.rdbuf(); - *out 
= buffer.str(); - return true; -} - -std::string Base64Encode(const std::string& data) { - static const char kBase64Chars[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(((data.size() + 2) / 3) * 4); - int val = 0; - int valb = -6; - for (unsigned char c : data) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - out.push_back(kBase64Chars[(val >> valb) & 0x3F]); - valb -= 6; - } - } - if (valb > -6) { - out.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]); - } - while (out.size() % 4) { - out.push_back('='); - } - return out; -} - -uint32_t RotateRight(uint32_t value, uint32_t bits) { - return (value >> bits) | (value << (32 - bits)); -} - -std::string Sha256Hex(const std::string& input) { - static const uint32_t k[64] = { - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, - 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, - 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, - 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, - 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, - 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, - 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, - 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, - 0xc67178f2, - }; - uint32_t h[8] = { - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, - }; - - std::vector msg(input.begin(), input.end()); - const uint64_t bit_len = static_cast(msg.size()) * 8; - msg.push_back(0x80); - while ((msg.size() % 64) != 56) { - msg.push_back(0x00); - } - for (int i = 7; i >= 0; --i) { - msg.push_back(static_cast((bit_len >> (i 
* 8)) & 0xFF)); - } - - for (size_t offset = 0; offset < msg.size(); offset += 64) { - uint32_t w[64]; - for (size_t i = 0; i < 16; ++i) { - const size_t idx = offset + i * 4; - w[i] = (static_cast(msg[idx]) << 24) | - (static_cast(msg[idx + 1]) << 16) | - (static_cast(msg[idx + 2]) << 8) | - (static_cast(msg[idx + 3])); - } - for (size_t i = 16; i < 64; ++i) { - const uint32_t s0 = RotateRight(w[i - 15], 7) ^ RotateRight(w[i - 15], 18) ^ (w[i - 15] >> 3); - const uint32_t s1 = RotateRight(w[i - 2], 17) ^ RotateRight(w[i - 2], 19) ^ (w[i - 2] >> 10); - w[i] = w[i - 16] + s0 + w[i - 7] + s1; - } - - uint32_t a = h[0]; - uint32_t b = h[1]; - uint32_t c = h[2]; - uint32_t d = h[3]; - uint32_t e = h[4]; - uint32_t f = h[5]; - uint32_t g = h[6]; - uint32_t h0 = h[7]; - - for (size_t i = 0; i < 64; ++i) { - const uint32_t s1 = RotateRight(e, 6) ^ RotateRight(e, 11) ^ RotateRight(e, 25); - const uint32_t ch = (e & f) ^ (~e & g); - const uint32_t temp1 = h0 + s1 + ch + k[i] + w[i]; - const uint32_t s0 = RotateRight(a, 2) ^ RotateRight(a, 13) ^ RotateRight(a, 22); - const uint32_t maj = (a & b) ^ (a & c) ^ (b & c); - const uint32_t temp2 = s0 + maj; - - h0 = g; - g = f; - f = e; - e = d + temp1; - d = c; - c = b; - b = a; - a = temp1 + temp2; - } - - h[0] += a; - h[1] += b; - h[2] += c; - h[3] += d; - h[4] += e; - h[5] += f; - h[6] += g; - h[7] += h0; - } - - std::ostringstream oss; - oss << std::hex << std::setfill('0'); - for (size_t i = 0; i < 8; ++i) { - oss << std::setw(8) << h[i]; - } - return oss.str(); -} - -} // namespace lazyllm diff --git a/csrc/src/README.md b/csrc/src/README.md deleted file mode 100644 index e69de29bb..000000000 From 1854448b3e3428edfa53c89d63ed2e7b3fee1605 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 30 Jan 2026 18:00:17 +0800 Subject: [PATCH 09/40] adaptor --- csrc/CMakeLists.txt | 3 +- csrc/adaptor/adaptor.cpp | 2 + csrc/adaptor/adaptor_base_wrapper.hpp | 35 +++++++++++++ csrc/adaptor/adptor.cpp | 1 - csrc/adaptor/document_store.hpp | 74 
++++++++++++--------------- csrc/core/include/adaptor_base.hpp | 21 ++++++++ 6 files changed, 93 insertions(+), 43 deletions(-) create mode 100644 csrc/adaptor/adaptor.cpp create mode 100644 csrc/adaptor/adaptor_base_wrapper.hpp delete mode 100644 csrc/adaptor/adptor.cpp create mode 100644 csrc/core/include/adaptor_base.hpp diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 3ee930048..a824c99c5 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -9,6 +9,7 @@ add_compile_options( -Wshadow ) +# Third party libs include(FetchContent) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) find_package(pybind11 CONFIG REQUIRED) @@ -28,7 +29,7 @@ endif() file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/core/src/*.cpp") add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) -target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/include) target_link_libraries(lazyllm_core PUBLIC xxhash) # Config lazyllm_adaptor lib which maintains callback invocations. 
diff --git a/csrc/adaptor/adaptor.cpp b/csrc/adaptor/adaptor.cpp new file mode 100644 index 000000000..cd57a3bbe --- /dev/null +++ b/csrc/adaptor/adaptor.cpp @@ -0,0 +1,2 @@ +#include "adaptor_base_wrapper.hpp" +#include "document_store.hpp" diff --git a/csrc/adaptor/adaptor_base_wrapper.hpp b/csrc/adaptor/adaptor_base_wrapper.hpp new file mode 100644 index 000000000..e608097e6 --- /dev/null +++ b/csrc/adaptor/adaptor_base_wrapper.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "adaptor_base.hpp" + + +namespace lazyllm { + +class AdaptorBaseWrapper : public AdaptorBase { + pybind11::object _py_obj; +public: + AdaptorBaseWrapper(const pybind11::object &obj) : _py_obj(obj) {} + std::any call( + const std::string& func_name, + const std::unordered_map& args) const override final + { + pybind11::gil_scoped_acquire gil; + pybind11::object func = _py_obj.attr(func_name.c_str()); + return call_impl(func_name, func, args); + } + + virtual std::any call_impl( + const std::string& func_name, + const pybind11::object& func, + const std::unordered_map& args) const = 0; +}; + +} \ No newline at end of file diff --git a/csrc/adaptor/adptor.cpp b/csrc/adaptor/adptor.cpp deleted file mode 100644 index 3e54ade0a..000000000 --- a/csrc/adaptor/adptor.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "document_store.hpp" diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index c12d1800d..ee4d3902f 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -9,6 +8,9 @@ #include #include +#include "adaptor_base_wrapper.hpp" +#include "doc_node.hpp" + namespace lazyllm { struct NodeGroup { @@ -25,7 +27,7 @@ struct NodeGroup { parent(parent), display_name(display_name), type(type) {} }; -class DocumentStore { +class DocumentStore : public AdaptorBaseWrapper { public: // RAG system metadata keys static 
constexpr std::string_view RAG_KB_ID_KEY = "kb_id"; @@ -33,13 +35,13 @@ class DocumentStore { DocumentStore() = delete; explicit DocumentStore( - const pybind11::object &store, + const pybind11::object& store, const std::unordered_map &map) : - _py_store(store), _node_groups_map(map) {} + AdaptorBaseWrapper(store), _node_groups_map(map) {} - // Cache-aware factory to avoid rebuilding wrappers for the same Python store. + // Cache-aware factory to avoid rebuilding adaptor for the same Python store. static std::shared_ptr from_store( - const pybind11::object &store, const std::unordered_map &map) { + const pybind11::object& store, const std::unordered_map& map) { if (store.is_none()) return nullptr; pybind11::gil_scoped_acquire gil; @@ -50,49 +52,39 @@ class DocumentStore { if (auto existing = it->second.lock()) return existing; } - auto created = std::shared_ptr(new DocumentStore(store, map)); + auto created = std::make_shared(store, map); cache[key] = created; return created; } - bool is_group_active(const std::string& grp) const { - pybind11::gil_scoped_acquire gil; - pybind11::object fn = _py_store.attr("is_group_active"); - return fn(grp).cast(); - } - - pybind11::object get_node( - const std::string& group_name, - const std::string& uid, - const std::string& kb_id - ) const { - pybind11::gil_scoped_acquire gil; - pybind11::object fn = _py_store.attr("get_nodes"); - pybind11::object result = fn( - pybind11::arg("group_name") = group_name, - pybind11::arg("uids") = std::vector({uid}), - pybind11::arg("kb_id") = kb_id, - pybind11::arg("display") = true); - return result.cast()[0]; - } + std::any call_impl( + const std::string& func_name, + const pybind11::object& func, + const std::unordered_map& args) const override + { + if (func_name == "is_group_active") { + return func(args.at("group")).cast(); + } + else if (func_name == "get_node") { + return func( + pybind11::arg("group_name") = std::any_cast(args.at("group_name")), + pybind11::arg("uids") = 
std::vector({std::any_cast(args.at("uid"))}), + pybind11::arg("kb_id") = std::any_cast(args.at("kb_id")), + pybind11::arg("display") = true + ).cast()[0].cast(); + } + else if (func_name == "get_nodes") { + return func( + pybind11::arg("group_name") = std::any_cast(args.at("group_name")), + pybind11::arg("kb_id") = std::any_cast(args.at("kb_id")), + pybind11::arg("doc_ids") = std::vector({std::any_cast(args.at("doc_id"))}) + ).cast>(); + } - std::vector get_nodes( - const std::string& group_name, - const std::string& kb_id, - const std::string& doc_id - ) const { - pybind11::gil_scoped_acquire gil; - pybind11::object fn = _py_store.attr("get_nodes"); - pybind11::object result = fn( - pybind11::arg("group_name") = group_name, - pybind11::arg("kb_id") = kb_id, - pybind11::arg("doc_ids") = std::vector({doc_id})); - return result.cast>(); + throw std::runtime_error("Unknown DocumentStore function: " + func_name); } private: - // Keep the underlying Python store object alive for callback invocations. - pybind11::object _py_store; std::unordered_map _node_groups_map; // Cache by Python object identity to ensure one wrapper per store instance. 
diff --git a/csrc/core/include/adaptor_base.hpp b/csrc/core/include/adaptor_base.hpp new file mode 100644 index 000000000..bee43cbaa --- /dev/null +++ b/csrc/core/include/adaptor_base.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include +#include +#include + +namespace lazyllm { + +struct Arg { + std::string name; + std::any value; +}; + +struct AdaptorBase { + virtual ~AdaptorBase() = default; + virtual std::any call( + const std::string& func_name, + const std::unordered_map& args) const = 0; +}; + +} // namespace lazyllm From 1484fc7365182d85c9416a9f3e2d79470e176894 Mon Sep 17 00:00:00 2001 From: yzh Date: Mon, 2 Feb 2026 17:32:21 +0800 Subject: [PATCH 10/40] finish doc_node init --- csrc/adaptor/document_store.hpp | 6 +-- csrc/binding/export_doc_node.cpp | 80 ++++++++++---------------------- csrc/core/include/doc_node.hpp | 30 ++++++------ csrc/core/include/utils.hpp | 13 +++--- 4 files changed, 49 insertions(+), 80 deletions(-) diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index ee4d3902f..9e8868c03 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -57,6 +57,9 @@ class DocumentStore : public AdaptorBaseWrapper { return created; } +private: + std::unordered_map _node_groups_map; + std::any call_impl( const std::string& func_name, const pybind11::object& func, @@ -84,9 +87,6 @@ class DocumentStore : public AdaptorBaseWrapper { throw std::runtime_error("Unknown DocumentStore function: " + func_name); } -private: - std::unordered_map _node_groups_map; - // Cache by Python object identity to ensure one wrapper per store instance. 
static std::unordered_map> &store_cache() { static std::unordered_map> cache; diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index aaf5051c3..d528695dc 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -1,5 +1,4 @@ #include -#include #include "lazyllm.hpp" #include "document_store.hpp" @@ -21,90 +20,61 @@ lazyllm::DocNode init( std::optional global_metadata, std::optional text ) { - if (content && text) { + if (content && text) throw std::invalid_argument("`text` and `content` cannot be set at the same time."); - } - // Find parent node. lazyllm::DocNode* p_parent_node = nullptr; - std::shared_ptr store_bridge = nullptr; + std::shared_ptr store_adaptor = nullptr; // Build node groups map. // Usually, parent + store + node_groups are not None at the same time. if (parent && !store.is_none() && node_groups && global_metadata && group) { std::unordered_map node_groups_map; node_groups_map.reserve(node_groups->size()); - for (const auto &entry : *node_groups) { - auto &group_key = entry.first; - auto &group_dict = entry.second; + for (const auto& [group_key, group_dict] : *node_groups) { node_groups_map.emplace(group_key, lazyllm::NodeGroup( std::any_cast(group_dict.at(std::string("parent"))), std::any_cast(group_dict.at(std::string("display_name"))) )); } - auto store_bridge = lazyllm::DocumentStore::from_store(store, node_groups_map); + store_adaptor = lazyllm::DocumentStore::from_store(store, node_groups_map); auto kb_id = std::any_cast((*global_metadata)[ std::string(lazyllm::DocumentStore::RAG_KB_ID_KEY)]); auto doc_id = std::any_cast((*global_metadata)[ std::string(lazyllm::DocumentStore::RAG_DOC_ID_KEY)]); - if (std::holds_alternative(*parent)) { - const std::string& parent_uid = std::get(*parent); - p_parent_node = store_bridge->get_node(parent_uid, *group, kb_id).cast(); + if (const auto* parent_uid = std::get_if(&*parent)) { + p_parent_node = std::any_cast(store_adaptor->call("get_node", 
+ {{"group_name", *group}, {"uid", *parent_uid}, {"kb_id", kb_id}})); } else p_parent_node = std::get(*parent).cast(); } + std::string raw_text; + if (content) { + if (const auto* single_text = std::get_if(&*content)) + raw_text = std::move(*single_text); + else + raw_text = lazyllm::JoinLines(std::get>(*content)); + } + else if (text){ + raw_text = std::move(*text); + } + lazyllm::DocNode node( - content.value_or(text.value_or(std::string(""))), - uid.value_or(lazyllm::GenerateUUID()), + std::string_view(raw_text), group.value_or(""), + uid.value_or(""), p_parent_node, - store_bridge, - embedding.value_or(lazyllm::DocNode::EmbeddingVec()), metadata.value_or(lazyllm::DocNode::Metadata()), - global_metadata.value_or(lazyllm::DocNode::Metadata()) + std::make_shared( + global_metadata.value_or(lazyllm::DocNode::Metadata())) ); - - - if (embedding.has_value()) { - node.set_embedding(*embedding); - } - - if (!metadata.is_none()) { - if (!py::isinstance(metadata)) { - throw std::invalid_argument("`metadata` must be a dict."); - } - node.set_metadata(dict_to_metadata(metadata.cast())); - } - - if (!global_metadata.is_none()) { - if (!py::isinstance(global_metadata)) { - throw std::invalid_argument("`global_metadata` must be a dict."); - } - node.set_global_metadata(dict_to_metadata(global_metadata.cast())); - } - - if (!parent.is_none()) { - if (py::isinstance(parent)) { - node.set_parent(parent.cast>()); - } else if (py::isinstance(parent)) { - throw std::invalid_argument("`parent` as str is not supported in C++ binding."); - } else { - throw std::invalid_argument("`parent` must be a DocNode or None."); - } - } - - std::shared_ptr store_bridge; - if (!store.is_none()) { - py::object store_obj = store; - store_bridge = lazyllm::DocumentStore::from_store(store_obj, node_groups_map); - } - node.set_store(std::move(store_bridge)); - - (void)node_groups; + if (store_adaptor) node.set_store(store_adaptor); + if (embedding) node.set_embedding_vec(*embedding); + if 
(!raw_text.empty()) node.set_root_text(std::move(raw_text)); return node; } diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 11597ca92..bcb7577cf 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -13,6 +13,7 @@ #include #include "utils.hpp" +#include "adaptor_base.hpp" namespace lazyllm { @@ -25,7 +26,6 @@ class DocNode { using EmbeddingFun = std::function(const std::string&, const std::string&)>; using EmbeddingVec = std::unordered_map>; - DocNode* _p_parent_node = nullptr; private: std::shared_ptr _p_root_text = nullptr; std::string_view _text_view; @@ -34,7 +34,7 @@ class DocNode { mutable size_t _text_hash = 0; Metadata _metadata; - std::shared_ptr _p_global_metadata; + std::shared_ptr _p_global_metadata; std::vector _excluded_embed_metadata_keys; std::vector _excluded_llm_metadata_keys; @@ -43,7 +43,9 @@ class DocNode { double _relevance_score = .0; double _similarity_score = .0; + const DocNode* _p_parent_node = nullptr; Children _children; + std::shared_ptr _p_store = nullptr; public: DocNode() = delete; @@ -51,21 +53,17 @@ class DocNode { const std::string_view& text_view, const std::string& group_name, const std::string& uid = "", - const EmbeddingVec& embedding_vec = {}, + const DocNode* p_parent_node = nullptr, const Metadata& metadata = {}, - const std::shared_ptr& global_metadata = {}, - const std::shared_ptr& p_raw_text = nullptr + const std::shared_ptr& global_metadata = {} ) : _text_view(text_view), _group_name(group_name), - _uid(uid), + _uid(uid.empty() ? 
GenerateUUID() : uid), + _p_parent_node(p_parent_node), _metadata(metadata), - _p_global_metadata(global_metadata), - _embedding_vec(embedding_vec), - _p_root_text(p_raw_text) - { - if (uid.empty()) _uid = GenerateUUID(); - } + _p_global_metadata(global_metadata) + {} DocNode(const DocNode&) = default; DocNode& operator=(const DocNode&) = default; @@ -77,18 +75,19 @@ class DocNode { std::vector find_children(const std::vector& nodes) const { std::vector children; for (auto p_node : nodes) - if (p_node->get_parent_node() == _p_parent_node) + if (p_node->get_p_parent_node() == _p_parent_node) children.push_back(p_node); return children; } // Getter and Setter + void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } const std::string& get_uid() const { return _uid; } const std::string& get_group_name() const { return _group_name; } const std::string_view& get_text_view() const { return _text_view; } const std::string& get_text() const { return std::string(_text_view); } - void set_text(const std::string& text) { - _p_root_text = std::make_shared(text); + void set_root_text(const std::string&& text) { + _p_root_text = std::make_shared(std::move(text)); _text_view = *_p_root_text; _text_hash = evaluate_text_hash(); } @@ -96,6 +95,7 @@ class DocNode { if (_text_hash == 0) _text_hash = evaluate_text_hash(); return _text_hash; } + void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } const DocNode* get_parent_node() { return _p_parent_node; } void set_parent_node(DocNode* p_parent_node) { _p_parent_node = p_parent_node; } diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 76d50c6d1..384e251d5 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -10,14 +10,13 @@ namespace lazyllm { std::string JoinLines(const std::vector& lines) { - std::ostringstream oss; - for (size_t i = 0; i < lines.size(); ++i) { - if (i > 0) { - oss << "\n"; - } - oss << lines[i]; + if (lines.empty()) return {}; + 
std::string out = lines.front(); + for (size_t i = 1; i < lines.size(); ++i) { + out += '\n'; + out += lines[i]; } - return oss.str(); + return out; } std::string Trim(const std::string& value) { From a69f82f873135ce9f43a41eef8f66aa984da75a7 Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 3 Feb 2026 15:04:36 +0800 Subject: [PATCH 11/40] children --- csrc/adaptor/adaptor_base_wrapper.hpp | 2 +- csrc/adaptor/document_store.hpp | 30 ++++++++++-- csrc/binding/export_doc_node.cpp | 34 +++++++------- csrc/core/include/doc_node.hpp | 67 ++++++++++----------------- csrc/core/include/utils.hpp | 15 +++++- 5 files changed, 82 insertions(+), 66 deletions(-) diff --git a/csrc/adaptor/adaptor_base_wrapper.hpp b/csrc/adaptor/adaptor_base_wrapper.hpp index e608097e6..1b1da0a39 100644 --- a/csrc/adaptor/adaptor_base_wrapper.hpp +++ b/csrc/adaptor/adaptor_base_wrapper.hpp @@ -22,7 +22,7 @@ class AdaptorBaseWrapper : public AdaptorBase { const std::unordered_map& args) const override final { pybind11::gil_scoped_acquire gil; - pybind11::object func = _py_obj.attr(func_name.c_str()); + pybind11::object func = pybind11::getattr(_py_obj, func_name.c_str(), pybind11::none()); return call_impl(func_name, func, args); } diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index 9e8868c03..1564a52cc 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -29,10 +29,6 @@ struct NodeGroup { class DocumentStore : public AdaptorBaseWrapper { public: - // RAG system metadata keys - static constexpr std::string_view RAG_KB_ID_KEY = "kb_id"; - static constexpr std::string_view RAG_DOC_ID_KEY = "docid"; - DocumentStore() = delete; explicit DocumentStore( const pybind11::object& store, @@ -57,6 +53,29 @@ class DocumentStore : public AdaptorBaseWrapper { return created; } + DocNode::Children get_node_children(const DocNode* node) const { + DocNode::Children out; + auto& kb_id = 
std::any_cast(node->_p_global_metadata->at(std::string(RAG_KEY_KB_ID))); + auto& doc_id = std::any_cast(node->_p_global_metadata->at(std::string(RAG_KEY_DOC_ID))); + auto& group_name = node->get_group_name(); + for(auto& [current_group_name, group] : _node_groups_map) { + if (group.parent != group_name) continue; + if (!std::any_cast(call("is_group_active", {{"group", current_group_name}}))) continue; + auto nodes_in_group = std::any_cast>(call("get_nodes", { + {"group_name", current_group_name}, + {"kb_id", kb_id}, + {"doc_ids", std::vector({doc_id})} + })); + + std::vector children; + children.reserve(nodes_in_group.size()); + for (auto* n : nodes_in_group) + if (n->get_parent_node() == node) children.push_back(n); + out[current_group_name] = children; + } + return out; + } + private: std::unordered_map _node_groups_map; @@ -83,6 +102,9 @@ class DocumentStore : public AdaptorBaseWrapper { pybind11::arg("doc_ids") = std::vector({std::any_cast(args.at("doc_id"))}) ).cast>(); } + else if (func_name == "get_node_children") { + return get_node_children(std::any_cast(args.at("node"))); + } throw std::runtime_error("Unknown DocumentStore function: " + func_name); } diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index d528695dc..34ddccbb7 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -39,10 +39,10 @@ lazyllm::DocNode init( } store_adaptor = lazyllm::DocumentStore::from_store(store, node_groups_map); - auto kb_id = std::any_cast((*global_metadata)[ - std::string(lazyllm::DocumentStore::RAG_KB_ID_KEY)]); - auto doc_id = std::any_cast((*global_metadata)[ - std::string(lazyllm::DocumentStore::RAG_DOC_ID_KEY)]); + auto kb_id = std::any_cast((*global_metadata).at( + std::string(lazyllm::RAG_KEY_KB_ID))); + auto doc_id = std::any_cast((*global_metadata).at( + std::string(lazyllm::RAG_KEY_DOC_ID))); if (const auto* parent_uid = std::get_if(&*parent)) { p_parent_node = 
std::any_cast(store_adaptor->call("get_node", @@ -53,18 +53,10 @@ lazyllm::DocNode init( } std::string raw_text; - if (content) { - if (const auto* single_text = std::get_if(&*content)) - raw_text = std::move(*single_text); - else - raw_text = lazyllm::JoinLines(std::get>(*content)); - } - else if (text){ - raw_text = std::move(*text); - } + lazyllm::DocNode node( - std::string_view(raw_text), + "", group.value_or(""), uid.value_or(""), p_parent_node, @@ -74,7 +66,15 @@ lazyllm::DocNode init( ); if (store_adaptor) node.set_store(store_adaptor); if (embedding) node.set_embedding_vec(*embedding); - if (!raw_text.empty()) node.set_root_text(std::move(raw_text)); + if (content) { + if (const auto* s = std::get_if(&*content)) + node.set_root_text(std::move(*s)); + else + node.set_root_texts(std::get>(*content)); + } + else if (text){ + node.set_root_text(std::move(*text)); + } return node; } @@ -97,7 +97,9 @@ void exportDocNode(py::module& m) { py::keep_alive<1, 7>() // Keep store alive ) .def_property_readonly("uid", &lazyllm::DocNode::uid) - number + @property + def number(self) -> int: + return self._metadata.get('lazyllm_store_num', 0) .def("set_text", &lazyllm::DocNode::set_text, py::arg("text")) .def("get_text", &lazyllm::DocNode::get_text); } diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index bcb7577cf..39c346421 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -22,29 +22,30 @@ enum class MetadataMode { ALL, EMBED, LLM, NONE }; class DocNode { public: using Metadata = std::unordered_map; - using Children = std::unordered_map>; + using Children = std::unordered_map>; using EmbeddingFun = std::function(const std::string&, const std::string&)>; using EmbeddingVec = std::unordered_map>; + Metadata _metadata; + std::shared_ptr _p_global_metadata; + EmbeddingVec _embedding_vec; private: - std::shared_ptr _p_root_text = nullptr; std::string_view _text_view; + std::shared_ptr _p_root_text = nullptr; + 
std::vector _root_texts = {}; std::string _group_name; std::string _uid; mutable size_t _text_hash = 0; - Metadata _metadata; - std::shared_ptr _p_global_metadata; std::vector _excluded_embed_metadata_keys; std::vector _excluded_llm_metadata_keys; - EmbeddingVec _embedding_vec; mutable std::set _embedding_state = {}; double _relevance_score = .0; double _similarity_score = .0; const DocNode* _p_parent_node = nullptr; - Children _children; + mutable Children _children; std::shared_ptr _p_store = nullptr; public: @@ -55,15 +56,16 @@ class DocNode { const std::string& uid = "", const DocNode* p_parent_node = nullptr, const Metadata& metadata = {}, - const std::shared_ptr& global_metadata = {} + const std::shared_ptr& global_metadata = {} ) : - _text_view(text_view), _group_name(group_name), _uid(uid.empty() ? GenerateUUID() : uid), _p_parent_node(p_parent_node), _metadata(metadata), _p_global_metadata(global_metadata) - {} + { + set_text_view(text_view); + } DocNode(const DocNode&) = default; DocNode& operator=(const DocNode&) = default; @@ -72,50 +74,35 @@ class DocNode { size_t evaluate_text_hash() const { return static_cast(XXH64(_text_view.data(), _text_view.size(), 0)); } - std::vector find_children(const std::vector& nodes) const { - std::vector children; - for (auto p_node : nodes) - if (p_node->get_p_parent_node() == _p_parent_node) - children.push_back(p_node); - return children; - } // Getter and Setter void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } const std::string& get_uid() const { return _uid; } const std::string& get_group_name() const { return _group_name; } const std::string_view& get_text_view() const { return _text_view; } + void set_text_view(const std::string_view& text_view) { + _text_view = text_view; + _text_hash = evaluate_text_hash(); + } const std::string& get_text() const { return std::string(_text_view); } void set_root_text(const std::string&& text) { _p_root_text = std::make_shared(std::move(text)); - _text_view 
= *_p_root_text; - _text_hash = evaluate_text_hash(); - } - const size_t& get_text_hash() const { - if (_text_hash == 0) _text_hash = evaluate_text_hash(); - return _text_hash; + set_text_view(*_p_root_text); } - void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } + void set_root_texts(const std::vector& texts) { set_root_text(JoinLines(texts)); } + size_t text_hash() const { return _text_hash; } const DocNode* get_parent_node() { return _p_parent_node; } void set_parent_node(DocNode* p_parent_node) { _p_parent_node = p_parent_node; } - - const std::vector& get_children() const { + const Children& get_children() const { + if (!_children.empty()) return _children; + if (_p_store == nullptr) return Children(); + _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); return _children; } + void set_children(const Children& children) { _children = children; } + - const Metadata& get_metadata() const { return _metadata; } - void set_metadata(const Metadata& metadata) { _metadata = metadata; } - std::shared_ptr global_metadata_ptr() { return _p_global_metadata; } - std::shared_ptr global_metadata_ptr() const { return _p_global_metadata; } - Metadata& global_metadata() { return *_p_global_metadata; } - const Metadata& global_metadata() const { return *_p_global_metadata; } - void set_global_metadata_ptr(std::shared_ptr p_global_metadata) { - _p_global_metadata = std::move(p_global_metadata); - } - void set_global_metadata(const Metadata& global_metadata) { - _p_global_metadata = std::make_shared(global_metadata); - } const std::vector& excluded_embed_metadata_keys() const { return _excluded_embed_metadata_keys; @@ -131,9 +118,6 @@ class DocNode { _excluded_llm_metadata_keys = keys; } - EmbeddingVec& embedding_vec() { return _embedding_vec; } - const EmbeddingVec& embedding_vec() const { return _embedding_vec; } - void set_embedding_vec(const EmbeddingVec& embedding_vec) { _embedding_vec = embedding_vec; } std::set& 
embedding_state() { return _embedding_state; } const std::set& embedding_state() const { return _embedding_state; } @@ -154,10 +138,7 @@ class DocNode { Children& children() { return _children; } const Children& children() const { return _children; } - void set_children(const Children& children) { _children = children; } - bool children_loaded() const { return _children_loaded; } - void set_children_loaded(bool children_loaded) { _children_loaded = children_loaded; } }; } // namespace lazyllm diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 384e251d5..56cf1c455 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -9,11 +9,22 @@ namespace lazyllm { -std::string JoinLines(const std::vector& lines) { +// RAG system metadata keys +constexpr std::string_view RAG_KEY_KB_ID = "kb_id"; +constexpr std::string_view RAG_KEY_DOC_ID = "docid"; +constexpr std::string_view RAG_KEY_DOC_PATH = "lazyllm_doc_path"; +constexpr std::string_view RAG_KEY_DOC_FILE_NAME = "file_name"; +constexpr std::string_view RAG_KEY_DOC_FILE_TYPE = "file_type"; +constexpr std::string_view RAG_KEY_DOC_FILE_SIZE = "file_size"; +constexpr std::string_view RAG_KEY_DOC_CREATION_DATE = "creation_date"; +constexpr std::string_view RAG_KEY_DOC_LAST_MODIFIED_DATE = "last_modified_date"; +constexpr std::string_view RAG_KEY_DOC_LAST_ACCESSED_DATE = "last_accessed_date"; + +std::string JoinLines(const std::vector& lines, char delim = '\n') { if (lines.empty()) return {}; std::string out = lines.front(); for (size_t i = 1; i < lines.size(); ++i) { - out += '\n'; + out += delim; out += lines[i]; } return out; From a6cfceb45a08c455055b6ec46dec7f965551a502 Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 4 Feb 2026 16:21:17 +0800 Subject: [PATCH 12/40] doc_node hpp --- csrc/adaptor/adaptor_base_wrapper.hpp | 6 +- csrc/adaptor/document_store.hpp | 12 +- csrc/binding/export_doc_node.cpp | 19 +- csrc/core/include/adaptor_base.hpp | 9 +- csrc/core/include/doc_node.hpp | 
121 +++++--- csrc/core/include/utils.hpp | 23 ++ csrc/core/src/doc_node.cpp | 405 -------------------------- 7 files changed, 128 insertions(+), 467 deletions(-) diff --git a/csrc/adaptor/adaptor_base_wrapper.hpp b/csrc/adaptor/adaptor_base_wrapper.hpp index 1b1da0a39..bafd78a60 100644 --- a/csrc/adaptor/adaptor_base_wrapper.hpp +++ b/csrc/adaptor/adaptor_base_wrapper.hpp @@ -13,10 +13,12 @@ namespace lazyllm { -class AdaptorBaseWrapper : public AdaptorBase { +class LAZYLLM_HIDDEN AdaptorBaseWrapper : public AdaptorBase { pybind11::object _py_obj; public: AdaptorBaseWrapper(const pybind11::object &obj) : _py_obj(obj) {} + virtual ~AdaptorBaseWrapper() = default; + std::any call( const std::string& func_name, const std::unordered_map& args) const override final @@ -32,4 +34,4 @@ class AdaptorBaseWrapper : public AdaptorBase { const std::unordered_map& args) const = 0; }; -} \ No newline at end of file +} diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index 1564a52cc..fff7477b3 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -17,17 +17,17 @@ struct NodeGroup { enum class Type { ORIGINAL, CHUNK, SUMMARY, IMAGE_INFO, QUESTION_ANSWER, OTHER }; - std::string parent; - std::string display_name; - Type type; + std::string _parent; + std::string _display_name; + Type _type; NodeGroup( const std::string& parent, const std::string& display_name, const Type& type = Type::ORIGINAL) : - parent(parent), display_name(display_name), type(type) {} + _parent(parent), _display_name(display_name), _type(type) {} }; -class DocumentStore : public AdaptorBaseWrapper { +class LAZYLLM_HIDDEN DocumentStore : public AdaptorBaseWrapper { public: DocumentStore() = delete; explicit DocumentStore( @@ -59,7 +59,7 @@ class DocumentStore : public AdaptorBaseWrapper { auto& doc_id = std::any_cast(node->_p_global_metadata->at(std::string(RAG_KEY_DOC_ID))); auto& group_name = node->get_group_name(); for(auto& [current_group_name, 
group] : _node_groups_map) { - if (group.parent != group_name) continue; + if (group._parent != group_name) continue; if (!std::any_cast(call("is_group_active", {{"group", current_group_name}}))) continue; auto nodes_in_group = std::any_cast>(call("get_nodes", { {"group_name", current_group_name}, diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 34ddccbb7..3fa2ef2db 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -11,7 +11,7 @@ lazyllm::DocNode init( std::optional uid, std::optional>> content, std::optional group, - std::optional embedding, + std::optional embedding, std::optional> parent, py::object store, std::optional(&*content)) node.set_root_text(std::move(*s)); @@ -92,14 +95,6 @@ void exportDocNode(py::module& m) { py::arg("node_groups") = py::none(), py::arg("metadata") = py::none(), py::arg("global_metadata") = py::none(), - py::arg("text") = py::none(), - - py::keep_alive<1, 7>() // Keep store alive - ) - .def_property_readonly("uid", &lazyllm::DocNode::uid) - @property - def number(self) -> int: - return self._metadata.get('lazyllm_store_num', 0) - .def("set_text", &lazyllm::DocNode::set_text, py::arg("text")) - .def("get_text", &lazyllm::DocNode::get_text); + py::arg("text") = py::none() + ); } diff --git a/csrc/core/include/adaptor_base.hpp b/csrc/core/include/adaptor_base.hpp index bee43cbaa..6bb729403 100644 --- a/csrc/core/include/adaptor_base.hpp +++ b/csrc/core/include/adaptor_base.hpp @@ -4,6 +4,12 @@ #include #include +#if defined(__GNUC__) || defined(__clang__) +#define LAZYLLM_HIDDEN __attribute__((visibility("hidden"))) +#else +#define LAZYLLM_HIDDEN +#endif + namespace lazyllm { struct Arg { @@ -11,7 +17,8 @@ struct Arg { std::any value; }; -struct AdaptorBase { +class AdaptorBase { +public: virtual ~AdaptorBase() = default; virtual std::any call( const std::string& func_name, diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 
39c346421..9d8dbae66 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -24,11 +24,15 @@ class DocNode { using Metadata = std::unordered_map; using Children = std::unordered_map>; using EmbeddingFun = std::function(const std::string&, const std::string&)>; - using EmbeddingVec = std::unordered_map>; + using EmbeddingVecs = std::unordered_map>; Metadata _metadata; std::shared_ptr _p_global_metadata; - EmbeddingVec _embedding_vec; + EmbeddingVecs _embedding_vecs; + std::set _pending_embedding_keys = {}; + double _relevance_score = .0; + double _similarity_score = .0; + private: std::string_view _text_view; std::shared_ptr _p_root_text = nullptr; @@ -37,12 +41,8 @@ class DocNode { std::string _uid; mutable size_t _text_hash = 0; - std::vector _excluded_embed_metadata_keys; - std::vector _excluded_llm_metadata_keys; - - mutable std::set _embedding_state = {}; - double _relevance_score = .0; - double _similarity_score = .0; + std::set _excluded_embed_metadata_keys; + std::set _excluded_llm_metadata_keys; const DocNode* _p_parent_node = nullptr; mutable Children _children; @@ -76,6 +76,10 @@ class DocNode { } // Getter and Setter + const DocNode* get_root_node() const { + if (_p_parent_node == nullptr) return this; + return _p_parent_node->get_root_node(); + } void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } const std::string& get_uid() const { return _uid; } const std::string& get_group_name() const { return _group_name; } @@ -84,61 +88,96 @@ class DocNode { _text_view = text_view; _text_hash = evaluate_text_hash(); } - const std::string& get_text() const { return std::string(_text_view); } + std::string get_metadata_string(MetadataMode mode = MetadataMode::ALL) const { + if (mode == MetadataMode::NONE) return ""; + + std::set valid_keys; + for (const auto& [key, _] : _metadata) valid_keys.insert(key); + + if (mode == MetadataMode::LLM) + valid_keys = SetDiff(valid_keys, _excluded_llm_metadata_keys); + else if (mode 
== MetadataMode::EMBED) + valid_keys = SetDiff(valid_keys, _excluded_embed_metadata_keys); + + std::vector kv_strings; + for (const std::string& key : valid_keys) + kv_strings.emplace_back(key + ":" + std::any_cast(_metadata.at(key))); + + return JoinLines(kv_strings); + } + std::string get_text(MetadataMode mode = MetadataMode::NONE) const { + if (mode == MetadataMode::NONE) return std::string(_text_view); + const auto& metadata_string = get_metadata_string(mode); + return metadata_string + "\n\n" + std::string(_text_view); + } void set_root_text(const std::string&& text) { _p_root_text = std::make_shared(std::move(text)); set_text_view(*_p_root_text); } void set_root_texts(const std::vector& texts) { set_root_text(JoinLines(texts)); } size_t text_hash() const { return _text_hash; } - const DocNode* get_parent_node() { return _p_parent_node; } + const DocNode* get_parent_node() const { return _p_parent_node; } void set_parent_node(DocNode* p_parent_node) { _p_parent_node = p_parent_node; } - const Children& get_children() const { + Children py_get_children() const { if (!_children.empty()) return _children; if (_p_store == nullptr) return Children(); _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); return _children; } void set_children(const Children& children) { _children = children; } - - - - - const std::vector& excluded_embed_metadata_keys() const { - return _excluded_embed_metadata_keys; + std::set get_excluded_embed_metadata_keys() const { + return SetUnion(get_root_node()->get_excluded_embed_metadata_keys(), _excluded_embed_metadata_keys); } - void set_excluded_embed_metadata_keys(const std::vector& keys) { + void set_excluded_embed_metadata_keys(const std::set& keys) { _excluded_embed_metadata_keys = keys; } - - const std::vector& excluded_llm_metadata_keys() const { - return _excluded_llm_metadata_keys; + std::set get_excluded_llm_metadata_keys() const { + return SetUnion(get_root_node()->get_excluded_llm_metadata_keys(), 
_excluded_llm_metadata_keys); } - void set_excluded_llm_metadata_keys(const std::vector& keys) { + void set_excluded_llm_metadata_keys(const std::set& keys) { _excluded_llm_metadata_keys = keys; } - - - std::set& embedding_state() { return _embedding_state; } - const std::set& embedding_state() const { return _embedding_state; } - void set_embedding_state(const std::set& embedding_state) { - _embedding_state = embedding_state; + std::string get_doc_path() const { + return std::any_cast(get_root_node()->_p_global_metadata->at(std::string(RAG_KEY_DOC_PATH))); } - - double relevance_score() const { return _relevance_score; } - void set_relevance_score(double relevance_score) { _relevance_score = relevance_score; } - - double similarity_score() const { return _similarity_score; } - void set_similarity_score(double similarity_score) { _similarity_score = similarity_score; } - - const std::variant& parent_node() const { return _p_parent_node; } - void set_parent_node(const std::variant& parent_node) { - _p_parent_node = parent_node; + void set_doc_path(const std::string& path) { + get_root_node()->_p_global_metadata->operator[](std::string(RAG_KEY_DOC_PATH)) = path; + } + auto py_get_children_uid() const { + auto children = py_get_children(); + std::unordered_map> children_uid; + for (auto& [group_name, nodes] : children) { + children_uid[group_name] = {}; + for (auto& node : nodes) + children_uid[group_name].push_back(node->get_uid()); + } + return children_uid; + } + std::string get_parent_uid() const { + auto parent = get_parent_node(); + if (parent == nullptr) return ""; + return parent->get_uid(); + } + std::set embedding_keys_undone(const std::set& keys_done) const { + if (keys_done.empty()) throw std::runtime_error("The ebmed_keys to be checked must be passed in.");; + std::set keys_undone; + for (const auto& key : keys_done) { + if (_embedding_vecs.find(key) == _embedding_vecs.end()) + keys_undone.insert(key); + } + return keys_undone; + } + void 
py_do_embedding(const std::unordered_map(const std::string&)>>& embedding_funcs) { + for (const auto& [key, func] : embedding_funcs) + _embedding_vecs[key] = func(get_text(MetadataMode::EMBED)); + } + void set_embedding_vec(const std::string& key, const std::vector& embedding_vec) { + _embedding_vecs[key] = embedding_vec; } - Children& children() { return _children; } - const Children& children() const { return _children; } - + bool operator==(const DocNode& other) const { return _uid == other._uid; } + bool operator!=(const DocNode& other) const { return _uid != other._uid; } }; } // namespace lazyllm diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 56cf1c455..7bf453f91 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -30,6 +30,29 @@ std::string JoinLines(const std::vector& lines, char delim = '\n') return out; } +template +std::vector ConcatVector(const std::vector& l, const std::vector& r) { + std::vector out; + out.reserve(l.size() + r.size()); + out.insert(out.end(), l.begin(), l.end()); + out.insert(out.end(), r.begin(), r.end()); + return out; +} + +template +std::set SetUnion(const std::set& l, const std::set& r) { + std::set out; + std::set_union(l.begin(), l.end(), r.begin(), r.end(), std::inserter(out, out.begin())); + return out; +} + +template +std::set SetDiff(const std::set& l, const std::set& r) { + std::set out; + std::set_difference(l.begin(), l.end(), r.begin(), r.end(), std::inserter(out, out.begin())); + return out; +} + std::string Trim(const std::string& value) { size_t start = 0; while (start < value.size() && std::isspace(static_cast(value[start]))) { diff --git a/csrc/core/src/doc_node.cpp b/csrc/core/src/doc_node.cpp index ca2d592f6..031d608c5 100644 --- a/csrc/core/src/doc_node.cpp +++ b/csrc/core/src/doc_node.cpp @@ -1,406 +1 @@ -#include -#include -#include -#include -#include -#include -#include -#include - #include "doc_node.hpp" -#include "utils.hpp" - -namespace lazyllm { - 
-const std::string& DocNode::content_text() const { - return _text; -} - -void DocNode::set_content(const std::vector& lines) { - _content_is_list = true; - _content = lines; - _text = JoinLines(lines); - content_hash(); -} - - - -const std::string& DocNode::get_text() const { - return _text_view; -} - -std::string DocNode::get_text_with_metadata(MetadataMode mode) const { - const std::string metadata_str = get_metadata_str(mode); - if (metadata_str.empty()) { - return _text; - } - if (_text.empty()) { - return metadata_str; - } - return metadata_str + "\n\n" + _text; -} - -void DocNode::set_store(std::shared_ptr store) { - _store = std::move(store); -} - -const std::shared_ptr& DocNode::store() const { - return _store; -} - -std::string DocNode::content_hash() const { - _content_hash = Sha256Hex(_text); - return _content_hash; -} - -DocNode::EmbeddingVec& DocNode::embedding() { - return _embedding; -} - -const DocNode::EmbeddingVec& DocNode::embedding() const { - return _embedding; -} - -void DocNode::set_embedding(const EmbeddingVec& embed) { - std::lock_guard lock(_embedding_mutex); - _embedding = embed; -} - -std::vector DocNode::has_missing_embedding(const std::vector& embed_keys) const { - std::vector missing; - if (embed_keys.empty()) { - return missing; - } - std::lock_guard lock(_embedding_mutex); - for (const auto& key : embed_keys) { - if (_embedding.find(key) == _embedding.end()) { - missing.push_back(key); - } - } - return missing; -} - -void DocNode::do_embedding(const std::unordered_map& embed) { - EmbeddingVec generated; - const std::string input = get_text_with_metadata(MetadataMode::EMBED); - for (const auto& item : embed) { - generated[item.first] = item.second(input, ""); - } - for (const auto& item : generated) { - _embedding[item.first] = item.second; - } -} - -void DocNode::set_embedding_value(const std::string& key, const std::vector& value) { - _embedding[key] = value; -} - -void DocNode::check_embedding_state(const std::string& embed_key) 
const { - while (true) { - { - if (_embedding.find(embed_key) != _embedding.end()) { - _embedding_state.erase(embed_key); - break; - } - } - std::this_thread::sleep_for(std::chrono::seconds(1)); - } -} - -DocNode* DocNode::parent() { - return _parent; -} - -const DocNode* DocNode::parent() const { - return _parent; -} - -void DocNode::set_parent(DocNode* parent) { - _parent = parent; -} - -DocNode::Children& DocNode::children() { - return _children; -} - -const DocNode::Children& DocNode::children() const { - return _children; -} - -void DocNode::set_children(const Children& children) { - _children = children; -} - -DocNode* DocNode::root_node() { - DocNode* node = this; - while (node->_parent != nullptr) { - node = node->_parent; - } - return node; -} - -const DocNode* DocNode::root_node() const { - const DocNode* node = this; - while (node->_parent != nullptr) { - node = node->_parent; - } - return node; -} - -bool DocNode::is_root_node() const { - return _parent == nullptr; -} - -DocNode::Metadata& DocNode::metadata() { - return _metadata; -} - -const DocNode::Metadata& DocNode::metadata() const { - return _metadata; -} - -void DocNode::set_metadata(const Metadata& metadata) { - _metadata = metadata; -} - -DocNode::Metadata& DocNode::global_metadata() { - return root_node()->_global_metadata; -} - -const DocNode::Metadata& DocNode::global_metadata() const { - return root_node()->_global_metadata; -} - -void DocNode::set_global_metadata(const Metadata& global_metadata) { - _global_metadata = global_metadata; -} - -std::vector DocNode::excluded_embed_metadata_keys() const { - std::set keys; - const DocNode* root = root_node(); - keys.insert(root->_excluded_embed_metadata_keys.begin(), root->_excluded_embed_metadata_keys.end()); - keys.insert(_excluded_embed_metadata_keys.begin(), _excluded_embed_metadata_keys.end()); - return std::vector(keys.begin(), keys.end()); -} - -void DocNode::set_excluded_embed_metadata_keys(const std::vector& keys) { - 
_excluded_embed_metadata_keys = keys; -} - -std::vector DocNode::excluded_llm_metadata_keys() const { - std::set keys; - const DocNode* root = root_node(); - keys.insert(root->_excluded_llm_metadata_keys.begin(), root->_excluded_llm_metadata_keys.end()); - keys.insert(_excluded_llm_metadata_keys.begin(), _excluded_llm_metadata_keys.end()); - return std::vector(keys.begin(), keys.end()); -} - -void DocNode::set_excluded_llm_metadata_keys(const std::vector& keys) { - _excluded_llm_metadata_keys = keys; -} - -std::string DocNode::docpath() const { - const auto& meta = global_metadata(); - const auto it = meta.find(kRagDocPath); - if (it == meta.end()) { - return ""; - } - return it->second; -} - -void DocNode::set_docpath(const std::string& path) { - if (!is_root_node()) { - throw std::runtime_error("Only root node can set docpath."); - } - global_metadata()[kRagDocPath] = path; -} - -std::string DocNode::get_children_str() const { - std::ostringstream oss; - oss << "{"; - bool first_group = true; - for (const auto& item : _children) { - if (!first_group) { - oss << ", "; - } - first_group = false; - oss << item.first << ": ["; - bool first_child = true; - for (const auto* node : item.second) { - if (!node) { - continue; - } - if (!first_child) { - oss << ", "; - } - first_child = false; - oss << node->uid(); - } - oss << "]"; - } - oss << "}"; - return oss.str(); -} - -std::string DocNode::get_parent_id() const { - return _parent ? 
_parent->uid() : ""; -} - -std::string DocNode::to_string() const { - std::ostringstream oss; - oss << "DocNode(id: " << _uid << ", group: " << _group << ", content: " << _text << ") parent: " - << get_parent_id() << ", children: " << get_children_str(); - return oss.str(); -} - -bool DocNode::operator==(const DocNode& other) const { - return _uid == other._uid; -} - -bool DocNode::operator!=(const DocNode& other) const { - return !(*this == other); -} - -std::size_t DocNode::hash() const { - return std::hash()(_uid); -} - -std::string DocNode::get_metadata_str(MetadataMode mode) const { - if (mode == MetadataMode::NONE) { - return ""; - } - std::set keys; - for (const auto& item : _metadata) { - keys.insert(item.first); - } - if (mode == MetadataMode::LLM) { - const auto excluded = excluded_llm_metadata_keys(); - for (const auto& key : excluded) { - keys.erase(key); - } - } else if (mode == MetadataMode::EMBED) { - const auto excluded = excluded_embed_metadata_keys(); - for (const auto& key : excluded) { - keys.erase(key); - } - } - std::ostringstream oss; - bool first = true; - for (const auto& key : keys) { - const auto it = _metadata.find(key); - if (it == _metadata.end()) { - continue; - } - if (!first) { - oss << "\n"; - } - first = false; - oss << key << ": " << it->second; - } - return oss.str(); -} - -std::string DocNode::get_content(MetadataMode mode) const { - if (mode == MetadataMode::LLM) { - return get_text_with_metadata(MetadataMode::LLM); - } - return get_text_with_metadata(mode); -} - -DocNode DocNode::with_score(double score) const { - DocNode node(*this); - node._relevance_score = score; - return node; -} - -DocNode DocNode::with_sim_score(double score) const { - DocNode node(*this); - node._similarity_score = score; - return node; -} - -double DocNode::relevance_score() const { - return _relevance_score; -} - -double DocNode::similarity_score() const { - return _similarity_score; -} - -QADocNode::QADocNode(const std::string& query, const 
std::string& answer) - : DocNode(query), _answer(Trim(answer)) {} - -QADocNode::QADocNode(const std::string& query, const std::string& answer, const std::string& uid, - const std::string& group) - : DocNode(query), _answer(Trim(answer)) { - if (!uid.empty()) { - _uid = uid; - } - _group = group; -} - -const std::string& QADocNode::answer() const { - return _answer; -} - -std::string QADocNode::get_text_with_metadata(MetadataMode mode) const { - if (mode == MetadataMode::LLM) { - std::ostringstream oss; - oss << "query:\n" << _text << "\nanswer\n" << _answer; - return oss.str(); - } - return DocNode::get_text_with_metadata(mode); -} - -ImageDocNode::ImageDocNode(const std::string& image_path) - : DocNode(image_path), _image_path(Trim(image_path)), _modality("image") { - set_text(_image_path); -} - -ImageDocNode::ImageDocNode(const std::string& image_path, const std::string& uid, const std::string& group) - : DocNode(image_path), _image_path(Trim(image_path)), _modality("image") { - if (!uid.empty()) { - _uid = uid; - } - _group = group; - set_text(_image_path); -} - -const std::string& ImageDocNode::image_path() const { - return _image_path; -} - -std::string ImageDocNode::get_content(MetadataMode mode) const { - if (mode == MetadataMode::EMBED) { - std::string file_bytes; - if (!ReadFileBinary(_image_path, &file_bytes)) { - return ""; - } - const std::string mime = ImageMimeType(_image_path); - const std::string base64 = Base64Encode(file_bytes); - if (mime.empty()) { - return base64; - } - return "data:" + mime + ";base64," + base64; - } - return _image_path; -} - -void ImageDocNode::do_embedding(const std::unordered_map& embed) { - EmbeddingVec generated; - const std::string input = get_content(MetadataMode::EMBED); - for (const auto& item : embed) { - generated[item.first] = item.second(input, _modality); - } - std::lock_guard lock(_embedding_mutex); - for (const auto& item : generated) { - _embedding[item.first] = item.second; - } -} - -std::string 
ImageDocNode::get_text_with_metadata(MetadataMode mode) const { - (void)mode; - return _image_path; -} - -} // namespace lazyllm From 0170a0e3f29a7e72111b46be19e3db318e535024 Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 4 Feb 2026 17:57:04 +0800 Subject: [PATCH 13/40] DocNode done --- csrc/binding/export_doc_node.cpp | 205 ++++++++++++++++++++++++++++++- csrc/core/include/doc_node.hpp | 2 +- csrc/core/include/utils.hpp | 96 +-------------- csrc/core/src/utils.cpp | 0 4 files changed, 210 insertions(+), 93 deletions(-) delete mode 100644 csrc/core/src/utils.cpp diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 3fa2ef2db..2715061e6 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -1,4 +1,7 @@ #include +#include +#include +#include #include "lazyllm.hpp" #include "document_store.hpp" @@ -15,7 +18,7 @@ lazyllm::DocNode init( std::optional> parent, py::object store, std::optional>> node_groups, + std::string, std::unordered_map>> node_groups, std::optional metadata, std::optional global_metadata, std::optional text @@ -82,7 +85,29 @@ lazyllm::DocNode init( return node; } +std::string DocNodeToString(const lazyllm::DocNode& node) { + py::dict d; + const auto children = node.py_get_children(); + for (const auto& [group, nodes] : children) { + py::list ids; + for (const auto* n : nodes) { + if (n) ids.append(n->get_uid()); + } + d[py::str(group)] = std::move(ids); + } + const std::string children_str = py::str(d).cast(); + return "DocNode(id: " + node.get_uid() + ", group: " + node.get_group_name() + + ", content: " + node.get_text(lazyllm::MetadataMode::NONE) + + ") parent: " + node.get_parent_uid() + ", children: " + children_str; +} + void exportDocNode(py::module& m) { + py::enum_(m, "MetadataMode") + .value("ALL", lazyllm::MetadataMode::ALL) + .value("EMBED", lazyllm::MetadataMode::EMBED) + .value("LLM", lazyllm::MetadataMode::LLM) + .value("NONE", lazyllm::MetadataMode::NONE); + 
py::class_(m, "DocNode") .def(py::init(&init), py::kw_only(), @@ -96,5 +121,181 @@ void exportDocNode(py::module& m) { py::arg("metadata") = py::none(), py::arg("global_metadata") = py::none(), py::arg("text") = py::none() - ); + ) + .def_property_readonly("uid", &lazyllm::DocNode::get_uid) + .def_property_readonly("group", &lazyllm::DocNode::get_group_name) + .def_property("content", + [](const lazyllm::DocNode& node) { + return std::string(node.get_text(lazyllm::MetadataMode::NONE)); + }, + [](lazyllm::DocNode& node, const std::variant>& content) { + if (const auto* content_str = std::get_if(&content)) { + node.set_root_text(std::move(*content_str)); + return; + } + else { + node.set_root_texts(std::get>(content)); + return; + } + } + ) + .def_property("number", + [](const lazyllm::DocNode& node) { + const auto it = node._metadata.find("lazyllm_store_num"); + if (it == node._metadata.end()) return 0; + return std::any_cast(it->second); + }, + [](lazyllm::DocNode& node, int value) { + node._metadata[std::string("lazyllm_store_num")] = value; + } + ) + .def_property_readonly("text", [](const lazyllm::DocNode& node) { return std::string(node.get_text()); }) + .def_property_readonly("content_hash", [](const lazyllm::DocNode& node) { + return lazyllm::to_hex(node.get_text_hash()); + }) + .def_property("embedding", + [](const lazyllm::DocNode& node) { return node._embedding_vecs; }, + [](lazyllm::DocNode& node, const lazyllm::DocNode::EmbeddingVecs& v) { + node._embedding_vecs = v; + } + ) + .def_property("parent", + [](const lazyllm::DocNode& node) { return node.get_parent_node(); }, + [](lazyllm::DocNode& node, lazyllm::DocNode* parent) { node.set_parent_node(parent); }, + py::return_value_policy::reference + ) + .def_property("children", + [](const lazyllm::DocNode& node) { return node.py_get_children(); }, + [](lazyllm::DocNode& node, const lazyllm::DocNode::Children& children) { + node.set_children(children); + } + ) + .def_property_readonly("root_node", + 
[](const lazyllm::DocNode& node) { return node.get_root_node(); }, + py::return_value_policy::reference + ) + .def_property_readonly("is_root_node", + [](const lazyllm::DocNode& node) { return node.get_parent_node() == nullptr; } + ) + .def_property("global_metadata", + [](const lazyllm::DocNode& node) { return *(node.get_root_node()->_p_global_metadata); }, + [](lazyllm::DocNode& node, const lazyllm::DocNode::Metadata& meta) { + node._p_global_metadata = std::make_shared(meta); + } + ) + .def_property("metadata", + [](const lazyllm::DocNode& node) { return node._metadata; }, + [](lazyllm::DocNode& node, const lazyllm::DocNode::Metadata& meta) { node._metadata = meta; } + ) + .def_property("excluded_embed_metadata_keys", + [](const lazyllm::DocNode& node) { return node.get_excluded_embed_metadata_keys(); }, + [](lazyllm::DocNode& node, const std::set& keys) { + node.set_excluded_embed_metadata_keys(keys); + } + ) + .def_property("excluded_llm_metadata_keys", + [](const lazyllm::DocNode& node) { return node.get_excluded_llm_metadata_keys(); }, + [](lazyllm::DocNode& node, const std::set& keys) { + node.set_excluded_llm_metadata_keys(keys); + } + ) + .def_property("docpath", + [](const lazyllm::DocNode& node) { return node.get_doc_path(); }, + [](lazyllm::DocNode& node, const std::string& path) { node.set_doc_path(path); } + ) + .def_property("embedding_state", + [](const lazyllm::DocNode& node) { return node._pending_embedding_keys; }, + [](lazyllm::DocNode& node, const std::set& keys) { + node._pending_embedding_keys = keys; + } + ) + .def_property("relevance_score", + [](const lazyllm::DocNode& node) { return node._relevance_score; }, + [](lazyllm::DocNode& node, double score) { node._relevance_score = score; } + ) + .def_property("similarity_score", + [](const lazyllm::DocNode& node) { return node._similarity_score; }, + [](lazyllm::DocNode& node, double score) { node._similarity_score = score; } + ) + .def("get_children_str", [](const lazyllm::DocNode& node) { + 
py::dict d; + const auto children = node.py_get_children(); + for (const auto& [group, nodes] : children) { + py::list ids; + for (const auto* n : nodes) if (n) ids.append(n->get_uid()); + d[py::str(group)] = std::move(ids); + } + return py::str(d); + }) + .def("get_parent_id", &lazyllm::DocNode::get_parent_uid) + .def("__str__", &DocNodeToString) + .def("__repr__", [](const lazyllm::DocNode& node) { + py::object cfg = py::module_::import("lazyllm").attr("config"); + py::object mode = py::module_::import("lazyllm").attr("Mode"); + py::object cfg_mode = cfg.attr("__getitem__")("mode"); + if (py::bool_(cfg_mode.equal(mode.attr("Debug")))) + return DocNodeToString(node); + return std::string(""; + }) + .def("__eq__", &lazyllm::DocNode::operator==, py::is_operator()) + .def("__hash__", [](const lazyllm::DocNode& node) { + return static_cast(std::hash{}(node.get_uid())); + }) + .def("__getstate__", [](const lazyllm::DocNode& node) { + py::dict st; + st["_uid"] = node.get_uid(); + st["_content"] = node.get_text(lazyllm::MetadataMode::NONE); + st["_group"] = node.get_group_name(); + st["_embedding"] = node._embedding_vecs; + st["_metadata"] = node._metadata; + st["_global_metadata"] = *(node._p_global_metadata); + st["_excluded_embed_metadata_keys"] = node.get_excluded_embed_metadata_keys(); + st["_excluded_llm_metadata_keys"] = node.get_excluded_llm_metadata_keys(); + st["_store"] = py::none(); + st["_node_groups"] = py::none(); + return st; + }) + .def("has_missing_embedding", [](const lazyllm::DocNode& node, + std::variant>& keys) { + if (const auto& single_key = std::get_if(&keys)) + return node.embedding_keys_undone({*single_key}); + else { + const auto& key_list = std::get>(keys); + return node.embedding_keys_undone(std::set(key_list.begin(), key_list.end())); + } + }) + .def("do_embedding", &lazyllm::DocNode::py_do_embedding, py::arg("embed")) + .def("set_embedding", [](lazyllm::DocNode& node, const std::string& key, const std::vector& value) { + 
node.set_embedding_vec(key, value); + }) + .def("check_embedding_state", [](lazyllm::DocNode& node, const std::string& key) { + while (true) { + if (node._embedding_vecs.find(key) != node._embedding_vecs.end()) { + node._pending_embedding_keys.erase(key); + break; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + }) + .def("get_content", [](const lazyllm::DocNode& node) { return node.get_text(lazyllm::MetadataMode::LLM); }) + .def("get_metadata_str", &lazyllm::DocNode::get_metadata_string, + py::arg("mode") = lazyllm::MetadataMode::ALL) + .def("get_text", &lazyllm::DocNode::get_text, py::arg("metadata_mode") = lazyllm::MetadataMode::NONE) + .def("to_dict", [](const lazyllm::DocNode& node) { + py::dict d; + d["content"] = node.get_text(lazyllm::MetadataMode::NONE); + d["embedding"] = node._embedding_vecs; + d["metadata"] = node._metadata; + return d; + }) + .def("with_score", [](const lazyllm::DocNode& node, double score) { + lazyllm::DocNode out = node; + out._relevance_score = score; + return out; + }) + .def("with_sim_score", [](const lazyllm::DocNode& node, double score) { + lazyllm::DocNode out = node; + out._similarity_score = score; + return out; + }); } diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 9d8dbae66..ddc0cdbd7 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -115,7 +115,7 @@ class DocNode { set_text_view(*_p_root_text); } void set_root_texts(const std::vector& texts) { set_root_text(JoinLines(texts)); } - size_t text_hash() const { return _text_hash; } + size_t get_text_hash() const { return _text_hash; } const DocNode* get_parent_node() const { return _p_parent_node; } void set_parent_node(DocNode* p_parent_node) { _p_parent_node = p_parent_node; } Children py_get_children() const { diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 7bf453f91..680955553 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -20,7 
+20,7 @@ constexpr std::string_view RAG_KEY_DOC_CREATION_DATE = "creation_date"; constexpr std::string_view RAG_KEY_DOC_LAST_MODIFIED_DATE = "last_modified_date"; constexpr std::string_view RAG_KEY_DOC_LAST_ACCESSED_DATE = "last_accessed_date"; -std::string JoinLines(const std::vector& lines, char delim = '\n') { +inline std::string JoinLines(const std::vector& lines, char delim = '\n') { if (lines.empty()) return {}; std::string out = lines.front(); for (size_t i = 1; i < lines.size(); ++i) { @@ -53,19 +53,13 @@ std::set SetDiff(const std::set& l, const std::set& r) { return out; } -std::string Trim(const std::string& value) { - size_t start = 0; - while (start < value.size() && std::isspace(static_cast(value[start]))) { - ++start; - } - size_t end = value.size(); - while (end > start && std::isspace(static_cast(value[end - 1]))) { - --end; - } - return value.substr(start, end - start); +inline std::string to_hex(size_t v) { + std::ostringstream oss; + oss << std::hex << v; + return oss.str(); } -std::string GenerateUUID() { +inline std::string GenerateUUID() { static const char HEX_CHAR[] = "0123456789abcdef"; static const int SEGS[] = {8, 4, 4, 4, 12}; @@ -84,82 +78,4 @@ std::string GenerateUUID() { return out; } -std::string ToLower(std::string value) { - std::transform(value.begin(), value.end(), value.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - return value; -} - -std::string GetExtension(const std::string& path) { - const size_t dot = path.find_last_of('.'); - if (dot == std::string::npos || dot + 1 >= path.size()) { - return ""; - } - return ToLower(path.substr(dot + 1)); -} - -std::string ImageMimeType(const std::string& path) { - const std::string ext = GetExtension(path); - if (ext == "jpg" || ext == "jpeg" || ext == "jfif" || ext == "jpe") { - return "image/jpeg"; - } - if (ext == "png" || ext == "apng") { - return "image/png"; - } - if (ext == "gif") { - return "image/gif"; - } - if (ext == "bmp" || ext == "dib") { - 
return "image/bmp"; - } - if (ext == "tif" || ext == "tiff") { - return "image/tiff"; - } - if (ext == "webp") { - return "image/webp"; - } - if (ext == "ico") { - return "image/x-icon"; - } - if (ext == "icns") { - return "image/icns"; - } - return ""; -} - -bool ReadFileBinary(const std::string& path, std::string* out) { - std::ifstream file(path.c_str(), std::ios::binary); - if (!file) { - return false; - } - std::ostringstream buffer; - buffer << file.rdbuf(); - *out = buffer.str(); - return true; -} - -std::string Base64Encode(const std::string& data) { - static const char kBase64Chars[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(((data.size() + 2) / 3) * 4); - int val = 0; - int valb = -6; - for (unsigned char c : data) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - out.push_back(kBase64Chars[(val >> valb) & 0x3F]); - valb -= 6; - } - } - if (valb > -6) { - out.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]); - } - while (out.size() % 4) { - out.push_back('='); - } - return out; -} - } // namespace lazyllm diff --git a/csrc/core/src/utils.cpp b/csrc/core/src/utils.cpp deleted file mode 100644 index e69de29bb..000000000 From 459cfd48b3d523d00e64b4f073caf98a703b8895 Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 5 Feb 2026 18:05:39 +0800 Subject: [PATCH 14/40] pending review --- csrc/CMakeLists.txt | 16 +- csrc/binding/export_doc_node.cpp | 4 +- csrc/binding/export_node_transform.cpp | 62 ++++ csrc/binding/export_text_spliter_base.cpp | 108 ++++++ csrc/binding/lazyllm.cpp | 2 + csrc/binding/lazyllm.hpp | 4 + csrc/cmake/third_party.cmake | 32 ++ csrc/core/include/node_transform.hpp | 98 ++++++ csrc/core/include/text_spliter_base.hpp | 392 ++++++++++++++++++++++ csrc/core/src/utils.cpp | 1 + 10 files changed, 704 insertions(+), 15 deletions(-) create mode 100644 csrc/binding/export_node_transform.cpp create mode 100644 csrc/binding/export_text_spliter_base.cpp create mode 
100644 csrc/cmake/third_party.cmake create mode 100644 csrc/core/include/node_transform.hpp create mode 100644 csrc/core/include/text_spliter_base.hpp create mode 100644 csrc/core/src/utils.cpp diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index a824c99c5..014e6abe2 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -10,20 +10,7 @@ add_compile_options( ) # Third party libs -include(FetchContent) -find_package(Python3 COMPONENTS Interpreter Development REQUIRED) -find_package(pybind11 CONFIG REQUIRED) - -find_package(xxHash QUIET) -if (NOT TARGET xxhash) - FetchContent_Declare( - xxhash - GIT_REPOSITORY https://github.com/Cyan4973/xxHash.git - GIT_TAG v0.8.2 - ) - FetchContent_Populate(xxhash) - add_subdirectory(${xxhash_SOURCE_DIR}/cmake_unofficial ${xxhash_BINARY_DIR}) -endif() +include(cmake/third_party.cmake) # Config lazyllm_core lib with pure cpp code. file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS @@ -31,6 +18,7 @@ file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/include) target_link_libraries(lazyllm_core PUBLIC xxhash) +target_link_libraries(lazyllm_core PUBLIC sentencepiece) # Config lazyllm_adaptor lib which maintains callback invocations. 
file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 2715061e6..b3b663ac0 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -8,7 +8,7 @@ #include "doc_node.hpp" #include "utils.hpp" -namespace py = pybind11; +namespace { lazyllm::DocNode init( std::optional uid, @@ -101,6 +101,8 @@ std::string DocNodeToString(const lazyllm::DocNode& node) { + ") parent: " + node.get_parent_uid() + ", children: " + children_str; } +} // namespace + void exportDocNode(py::module& m) { py::enum_(m, "MetadataMode") .value("ALL", lazyllm::MetadataMode::ALL) diff --git a/csrc/binding/export_node_transform.cpp b/csrc/binding/export_node_transform.cpp new file mode 100644 index 000000000..6a7788734 --- /dev/null +++ b/csrc/binding/export_node_transform.cpp @@ -0,0 +1,62 @@ +#include "lazyllm.hpp" + +#include "doc_node.hpp" +#include "node_transform.hpp" + +#include +#include + +namespace { + +lazyllm::NodeTransform::TransformKwargs kwargs_to_any_map(const py::kwargs& kwargs) { + lazyllm::NodeTransform::TransformKwargs out; + for (auto item : kwargs) { + const auto key = py::cast(item.first); + out[key] = py::cast(item.second); + } + return out; +} + +std::vector cast_documents(py::object documents) { + std::vector docs; + if (py::isinstance(documents) && !py::isinstance(documents)) { + for (auto item : documents) docs.push_back(py::cast(item)); + } else { + docs.push_back(documents.cast()); + } + return docs; +} + +} // namespace + +void exportNodeTransform(py::module& m) { + py::class_(m, "NodeTransform") + .def(py::init(), py::arg("num_workers") = 0) + .def( + "batch_forward", + [](lazyllm::NodeTransform& self, + py::object documents, + const std::string& node_group, + py::kwargs kwargs) { + auto docs = cast_documents(documents); + auto kw = kwargs_to_any_map(kwargs); + return self.batch_forward(docs, node_group, kw); + }, + py::arg("documents"), + 
py::arg("node_group"), + py::kw_only(), + py::return_value_policy::reference + ) + .def( + "with_name", + [](lazyllm::NodeTransform& self, py::object name, bool copy) -> lazyllm::NodeTransform& { + if (name.is_none()) return self; + self.with_name(name.cast(), copy); + return self; + }, + py::arg("name"), + py::kw_only(), + py::arg("copy") = true, + py::return_value_policy::reference + ); +} diff --git a/csrc/binding/export_text_spliter_base.cpp b/csrc/binding/export_text_spliter_base.cpp new file mode 100644 index 000000000..6302b1a8d --- /dev/null +++ b/csrc/binding/export_text_spliter_base.cpp @@ -0,0 +1,108 @@ +#include "lazyllm.hpp" + +#include "adaptor_base_wrapper.hpp" +#include "text_spliter_base.hpp" + +#include +#include + +namespace { + +class PyTokenizer final : public lazyllm::Tokenizer, public lazyllm::AdaptorBaseWrapper { +public: + explicit PyTokenizer(const py::object& obj) : AdaptorBaseWrapper(obj) {} + + std::vector encode(const std::string& text) const override { + auto result = call("encode", {{"text", text}}); + return std::any_cast>(result); + } + + std::string decode(const std::vector& token_ids) const override { + auto result = call("decode", {{"token_ids", token_ids}}); + return std::any_cast(result); + } + +private: + std::any call_impl( + const std::string& func_name, + const py::object& func, + const std::unordered_map& args) const override + { + if (func.is_none()) { + throw std::runtime_error("Tokenizer missing method: " + func_name); + } + if (func_name == "encode") { + const auto& text = std::any_cast(args.at("text")); + py::object result = func(text); + return std::any(result.cast>()); + } + if (func_name == "decode") { + const auto& ids = std::any_cast&>(args.at("token_ids")); + py::object result = func(ids); + return std::any(result.cast()); + } + throw std::runtime_error("Unknown tokenizer method: " + func_name); + } +}; + +} // namespace + +void exportTextSpliterBase(py::module& m) { + py::class_(m, "_TextSplitterBase") + 
.def(py::init< + std::optional, + std::optional, + std::optional, + std::optional>(), + py::arg("chunk_size") = py::none(), + py::arg("overlap") = py::none(), + py::arg("num_workers") = py::none(), + py::arg("sentencepiece_model") = py::none() + ) + .def("split_text", &lazyllm::_TextSplitterBase::split_text, + py::arg("text"), py::arg("metadata_size")) + .def("from_sentencepiece_model", &lazyllm::_TextSplitterBase::from_sentencepiece_model, + py::arg("model_path"), py::return_value_policy::reference) + .def("from_tokenizer", + [](lazyllm::_TextSplitterBase& self, py::object tokenizer) -> lazyllm::_TextSplitterBase& { + auto adaptor = std::make_shared(tokenizer); + self.set_tokenizer(adaptor); + return self; + }, + py::arg("tokenizer"), + py::return_value_policy::reference + ) + .def_static("set_default", + [](py::kwargs kwargs) { + std::unordered_map params; + for (auto item : kwargs) + params[py::cast(item.first)] = py::cast(item.second); + lazyllm::_TextSplitterBase::set_default(params); + } + ) + .def_static("get_default", + [](py::object name) { + if (name.is_none()) return py::cast(lazyllm::_TextSplitterBase::get_default()); + auto opt = lazyllm::_TextSplitterBase::get_default(name.cast()); + if (!opt.has_value()) return py::none(); + return py::cast(*opt); + }, + py::arg("param_name") = py::none() + ) + .def_static("reset_default", &lazyllm::_TextSplitterBase::reset_default); + + py::class_(m, "_TokenTextSplitter") + .def(py::init< + std::optional, + std::optional, + std::optional, + std::optional>(), + py::arg("chunk_size") = py::none(), + py::arg("overlap") = py::none(), + py::arg("num_workers") = py::none(), + py::arg("sentencepiece_model") = py::none() + ); + + m.def("split_text_keep_separator", &lazyllm::split_text_keep_separator, + py::arg("text"), py::arg("separator")); +} diff --git a/csrc/binding/lazyllm.cpp b/csrc/binding/lazyllm.cpp index b5282cdc7..96217a732 100644 --- a/csrc/binding/lazyllm.cpp +++ b/csrc/binding/lazyllm.cpp @@ -16,4 +16,6 @@ 
PYBIND11_MODULE(lazyllm_cpp, m) { options.disable_function_signatures(); exportDocNode(m); + exportNodeTransform(m); + exportTextSpliterBase(m); } diff --git a/csrc/binding/lazyllm.hpp b/csrc/binding/lazyllm.hpp index 537514ef6..c9bc4c965 100644 --- a/csrc/binding/lazyllm.hpp +++ b/csrc/binding/lazyllm.hpp @@ -7,5 +7,9 @@ #include #include +namespace py = pybind11; + void exportAddDocStr(pybind11::module& m); void exportDocNode(pybind11::module& m); +void exportNodeTransform(pybind11::module& m); +void exportTextSpliterBase(pybind11::module& m); diff --git a/csrc/cmake/third_party.cmake b/csrc/cmake/third_party.cmake new file mode 100644 index 000000000..c5dd0cfe9 --- /dev/null +++ b/csrc/cmake/third_party.cmake @@ -0,0 +1,32 @@ +include(FetchContent) + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +find_package(pybind11 CONFIG REQUIRED) + +find_package(xxHash QUIET) +if (NOT TARGET xxhash) + FetchContent_Declare( + xxhash + GIT_REPOSITORY https://github.com/Cyan4973/xxHash.git + GIT_TAG v0.8.2 + ) + FetchContent_Populate(xxhash) + add_subdirectory(${xxhash_SOURCE_DIR}/cmake_unofficial ${xxhash_BINARY_DIR}) +endif() + +find_package(sentencepiece QUIET) +if (NOT TARGET sentencepiece AND NOT TARGET sentencepiece::sentencepiece AND NOT TARGET sentencepiece-static) + FetchContent_Declare( + sentencepiece + GIT_REPOSITORY https://github.com/google/sentencepiece.git + GIT_TAG v0.2.0 + ) + FetchContent_MakeAvailable(sentencepiece) +endif() +if (TARGET sentencepiece::sentencepiece) + add_library(sentencepiece ALIAS sentencepiece::sentencepiece) +elseif (TARGET sentencepiece) + add_library(sentencepiece ALIAS sentencepiece) +elseif (TARGET sentencepiece-static) + add_library(sentencepiece ALIAS sentencepiece-static) +endif() diff --git a/csrc/core/include/node_transform.hpp b/csrc/core/include/node_transform.hpp new file mode 100644 index 000000000..c039ea4ba --- /dev/null +++ b/csrc/core/include/node_transform.hpp @@ -0,0 +1,98 @@ +#pragma once + 
+#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "doc_node.hpp" + +namespace lazyllm { + +class NodeTransform { +public: + using TransformKwargs = std::unordered_map; + using TransformItem = std::variant; + using TransformResult = std::vector; + + explicit NodeTransform(int worker_num = 0) : _worker_num(worker_num) {} + virtual ~NodeTransform() = default; + + std::vector batch_forward( + const std::vector& documents, + const std::string& node_group, + const TransformKwargs& kwargs = {}) + { + std::vector results; + for (auto* node : documents) { + if (node == nullptr) continue; + auto children = node->py_get_children(); + if (children.find(node_group) != children.end()) continue; + + auto splits = forward(node, node_group, kwargs); + children[node_group] = splits; + node->set_children(children); + results.insert(results.end(), splits.begin(), splits.end()); + } + return results; + } + + std::vector batch_forward( + DocNode* document, + const std::string& node_group, + const TransformKwargs& kwargs = {}) + { + if (document == nullptr) return {}; + return batch_forward(std::vector{document}, node_group, kwargs); + } + + virtual TransformResult transform(DocNode* document, const TransformKwargs& kwargs) = 0; + + NodeTransform& with_name(const std::optional& name, bool /*copy*/ = true) { + if (name.has_value()) _name = *name; + return *this; + } + + const std::string& name() const { return _name; } + int worker_num() const { return _worker_num; } + +protected: + std::vector forward( + DocNode* node, + const std::string& node_group, + const TransformKwargs& kwargs) + { + TransformResult raw = transform(node, kwargs); + std::vector out; + out.reserve(raw.size()); + + for (auto& item : raw) { + if (auto* text = std::get_if(&item)) { + if (text->empty()) continue; + auto child = std::make_unique("", node_group, "", node); + child->set_root_text(std::move(*text)); + auto* ptr = child.get(); + 
_owned_nodes.emplace_back(std::move(child)); + out.push_back(ptr); + } else { + auto* child = std::get(item); + if (child == nullptr) continue; + child->set_parent_node(node); + out.push_back(child); + } + } + return out; + } + +protected: + int _worker_num = 0; + std::string _name; + std::vector> _owned_nodes; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/text_spliter_base.hpp b/csrc/core/include/text_spliter_base.hpp new file mode 100644 index 000000000..8a3a2b801 --- /dev/null +++ b/csrc/core/include/text_spliter_base.hpp @@ -0,0 +1,392 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "doc_node.hpp" +#include "node_transform.hpp" + +namespace lazyllm { + +struct _Split { + std::string text; + bool is_sentence = false; + int token_size = 0; +}; + +inline std::vector split_text_keep_separator( + const std::string& text, + const std::string& separator) +{ + if (separator.empty()) return text.empty() ? 
std::vector() : std::vector{text}; + if (text.find(separator) == std::string::npos) return {text}; + + std::vector result; + size_t start = 0; + const size_t sep_len = separator.size(); + while (start < text.size()) { + const size_t idx = text.find(separator, start); + if (idx == std::string::npos) { + result.emplace_back(text.substr(start)); + break; + } + if (idx == 0) { + start = sep_len; + continue; + } + result.emplace_back(text.substr(start, idx - start + sep_len)); + start = idx + sep_len; + } + return result; +} + +class Tokenizer { +public: + virtual ~Tokenizer() = default; + virtual std::vector encode(const std::string& text) const = 0; + virtual std::string decode(const std::vector& token_ids) const = 0; +}; + +class SentencePieceTokenizer final : public Tokenizer { +public: + SentencePieceTokenizer() = default; + explicit SentencePieceTokenizer(const std::string& model_path) { load(model_path); } + + bool load(const std::string& model_path) { + auto status = _processor.Load(model_path); + if (!status.ok()) return false; + _loaded = true; + return true; + } + + std::vector encode(const std::string& text) const override { + ensure_loaded(); + std::vector ids; + auto status = _processor.Encode(text, &ids); + if (!status.ok()) throw std::runtime_error(status.ToString()); + return ids; + } + + std::string decode(const std::vector& token_ids) const override { + ensure_loaded(); + std::string text; + auto status = _processor.Decode(token_ids, &text); + if (!status.ok()) throw std::runtime_error(status.ToString()); + return text; + } + +private: + void ensure_loaded() const { + if (!_loaded) throw std::runtime_error("SentencePiece model not loaded."); + } + +private: + sentencepiece::SentencePieceProcessor _processor; + bool _loaded = false; +}; + +class _TextSplitterBase : public NodeTransform { +public: + using SplitFn = std::function(const std::string&)>; + + explicit _TextSplitterBase( + std::optional chunk_size = std::nullopt, + std::optional overlap = 
std::nullopt, + std::optional num_workers = std::nullopt, + std::optional sentencepiece_model = std::nullopt) + : NodeTransform(get_param_value("num_workers", num_workers, 0)), + _chunk_size(get_param_value("chunk_size", chunk_size, 1024)), + _overlap(get_param_value("overlap", overlap, 200)) + { + if (_overlap > _chunk_size) { + throw std::runtime_error( + "Got a larger chunk overlap than chunk size, should be smaller."); + } + if (_chunk_size <= 0 || _overlap < 0) + throw std::runtime_error("chunk size should > 0 and overlap should >= 0."); + + if (sentencepiece_model.has_value()) { + from_sentencepiece_model(*sentencepiece_model); + } else { + const char* env_model = std::getenv("LAZYLLM_SENTENCEPIECE_MODEL"); + if (env_model && *env_model != '\0') { + from_sentencepiece_model(std::string(env_model)); + } + } + } + + static void set_default(const std::unordered_map& params) { + std::lock_guard guard(default_params_lock()); + auto& defaults = default_params(); + for (const auto& [key, value] : params) defaults[key] = value; + } + + static std::unordered_map get_default() { + std::lock_guard guard(default_params_lock()); + return default_params(); + } + + static std::optional get_default(const std::string& param_name) { + std::lock_guard guard(default_params_lock()); + auto& params = default_params(); + auto it = params.find(param_name); + if (it == params.end()) return std::nullopt; + return it->second; + } + + static void reset_default() { + std::lock_guard guard(default_params_lock()); + default_params().clear(); + } + + _TextSplitterBase& from_sentencepiece_model(const std::string& model_path) { + auto sp = std::make_shared(); + if (!sp->load(model_path)) + throw std::runtime_error("Failed to load sentencepiece model: " + model_path); + _tokenizer = std::move(sp); + return *this; + } + + _TextSplitterBase& set_tokenizer(const std::shared_ptr& tokenizer) { + _tokenizer = tokenizer; + return *this; + } + + std::vector split_text(const std::string& text, int 
metadata_size) { + if (text.empty()) return {""}; + const int effective_chunk_size = _chunk_size - metadata_size; + if (effective_chunk_size <= 0) { + throw std::runtime_error( + "Metadata length is longer than chunk size."); + } + auto splits = _split(text, effective_chunk_size); + return _merge(splits, effective_chunk_size); + } + + TransformResult transform(DocNode* node, const TransformKwargs& /*kwargs*/) override { + if (node == nullptr) return {}; + auto chunks = split_text(node->get_text(), _get_metadata_size(node)); + TransformResult out; + out.reserve(chunks.size()); + for (auto& chunk : chunks) out.emplace_back(std::move(chunk)); + return out; + } + + virtual void set_split_fns( + const std::vector& /*split_fns*/, + const std::optional>& /*sub_split_fns*/ = std::nullopt) {} + + virtual void add_split_fn(const SplitFn& /*split_fn*/, const std::optional& /*index*/ = std::nullopt) {} + + virtual void clear_split_fns() {} + +protected: + virtual std::vector<_Split> _split(const std::string& text, int chunk_size) { + const int token_size = _token_size(text); + if (token_size <= chunk_size) return {_Split{text, true, token_size}}; + + auto [text_splits, is_sentence] = _get_splits_by_fns(text); + std::vector<_Split> results; + for (const auto& segment : text_splits) { + const int seg_token_size = _token_size(segment); + if (seg_token_size <= chunk_size) { + results.push_back(_Split{segment, is_sentence, seg_token_size}); + } else { + auto sub_results = _split(segment, chunk_size); + results.insert(results.end(), sub_results.begin(), sub_results.end()); + } + } + return results; + } + + virtual std::vector _merge(std::vector<_Split> splits, int chunk_size) { + if (splits.empty()) return {}; + if (splits.size() == 1) return {splits.front().text}; + + _Split end_split = splits.back(); + if (end_split.token_size == chunk_size && _overlap > 0) { + splits.pop_back(); + auto text_tokens = encode(end_split.text); + const size_t half = text_tokens.size() / 2; + 
std::vector p_tokens(text_tokens.begin(), text_tokens.begin() + half); + std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); + std::string p_text = decode(p_tokens); + std::string n_text = decode(n_tokens); + splits.push_back(_Split{p_text, end_split.is_sentence, _token_size(p_text)}); + splits.push_back(_Split{n_text, end_split.is_sentence, _token_size(n_text)}); + end_split = splits.back(); + } + + std::vector result; + for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { + _Split start_split = splits[static_cast(idx)]; + if (start_split.token_size <= _overlap && + end_split.token_size <= chunk_size - _overlap) { + const bool is_sentence = start_split.is_sentence && end_split.is_sentence; + const int token_size = start_split.token_size + end_split.token_size; + end_split = _Split{start_split.text + end_split.text, is_sentence, token_size}; + continue; + } + + if (end_split.token_size > chunk_size) { + throw std::runtime_error("split token size is greater than chunk size."); + } + + const int remaining_space = chunk_size - end_split.token_size; + const int overlap_len = std::min({_overlap, remaining_space, start_split.token_size}); + if (overlap_len > 0) { + auto start_tokens = encode(start_split.text); + std::vector overlap_tokens( + start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = decode(overlap_tokens); + end_split = _Split{overlap_text + end_split.text, end_split.is_sentence, + end_split.token_size + overlap_len}; + } + + result.insert(result.begin(), end_split.text); + end_split = start_split; + } + + result.insert(result.begin(), end_split.text); + return result; + } + + virtual std::pair, bool> _get_splits_by_fns( + const std::string& text) const + { + auto splits = split_text_keep_separator(text, "\n\n\n"); + if (splits.size() > 1) return {splits, true}; + + splits = regex_find_all(text, R"([^.!?。?!]+[.!?。?!]?)"); + if (splits.size() > 1) return {splits, true}; + + splits = 
regex_find_all(text, R"([^,.;。?!]+[,.;。?!]?)"); + if (splits.size() > 1) return {splits, false}; + + splits = split_text_keep_separator(text, " "); + if (splits.size() > 1) return {splits, false}; + + return {split_to_chars(text), false}; + } + + int _get_metadata_size(const DocNode* node) const { + return std::max( + _token_size(node->get_metadata_string(MetadataMode::EMBED)), + _token_size(node->get_metadata_string(MetadataMode::LLM))); + } + + int _token_size(const std::string& text) const { + return static_cast(encode(text).size()); + } + + std::vector encode(const std::string& text) const { + if (!_tokenizer) throw std::runtime_error("Tokenizer not initialized."); + return _tokenizer->encode(text); + } + + std::string decode(const std::vector& token_ids) const { + if (!_tokenizer) throw std::runtime_error("Tokenizer not initialized."); + return _tokenizer->decode(token_ids); + } + + static std::vector regex_find_all( + const std::string& text, const std::string& pattern) + { + std::regex re(pattern); + std::vector out; + for (auto it = std::sregex_iterator(text.begin(), text.end(), re); + it != std::sregex_iterator(); ++it) { + out.emplace_back(it->str()); + } + if (out.empty()) out.emplace_back(text); + return out; + } + + static std::vector split_to_chars(const std::string& text) { + std::vector out; + out.reserve(text.size()); + for (char c : text) out.emplace_back(1, c); + return out; + } + + static std::unordered_map& default_params() { + static std::unordered_map params; + return params; + } + + static std::recursive_mutex& default_params_lock() { + static std::recursive_mutex lock; + return lock; + } + + static int get_param_value( + const std::string& param_name, + const std::optional& value, + int default_value) + { + if (value.has_value()) return *value; + std::lock_guard guard(default_params_lock()); + auto& params = default_params(); + auto it = params.find(param_name); + if (it != params.end()) return it->second; + return default_value; + } + 
+protected: + int _chunk_size = 1024; + int _overlap = 200; + std::shared_ptr _tokenizer; +}; + +class _TokenTextSplitter : public _TextSplitterBase { +public: + explicit _TokenTextSplitter( + std::optional chunk_size = std::nullopt, + std::optional overlap = std::nullopt, + std::optional num_workers = std::nullopt, + std::optional sentencepiece_model = std::nullopt) + : _TextSplitterBase(chunk_size, overlap, num_workers, sentencepiece_model) {} + +protected: + std::vector<_Split> _split(const std::string& text, int chunk_size) override { + const int token_size = _token_size(text); + if (token_size <= chunk_size) return {_Split{text, true, token_size}}; + + std::vector<_Split> results; + auto tokens = encode(text); + size_t start_idx = 0; + size_t end_idx = std::min(start_idx + static_cast(chunk_size), tokens.size()); + while (start_idx < tokens.size()) { + std::vector chunk_tokens(tokens.begin() + start_idx, tokens.begin() + end_idx); + results.push_back(_Split{decode(chunk_tokens), true, static_cast(chunk_tokens.size())}); + if (end_idx >= tokens.size()) break; + start_idx = std::min(start_idx + static_cast(chunk_size - _overlap), tokens.size()); + end_idx = std::min(start_idx + static_cast(chunk_size), tokens.size()); + } + return results; + } + + std::vector _merge(std::vector<_Split> splits, int /*chunk_size*/) override { + std::vector out; + out.reserve(splits.size()); + for (auto& split : splits) out.emplace_back(std::move(split.text)); + return out; + } +}; + +} // namespace lazyllm diff --git a/csrc/core/src/utils.cpp b/csrc/core/src/utils.cpp new file mode 100644 index 000000000..0ee624c51 --- /dev/null +++ b/csrc/core/src/utils.cpp @@ -0,0 +1 @@ +#include "utils.hpp" From 5ea167c50897b665c00772bb9a449bcda5ca4e6c Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 6 Feb 2026 16:15:30 +0800 Subject: [PATCH 15/40] NodeTransform done --- csrc/binding/export_doc_node.cpp | 6 +- csrc/core/include/doc_node.hpp | 14 +++- csrc/core/include/node_transform.hpp | 107 
++++++++++++--------------- csrc/core/include/thread_pool.hpp | 99 +++++++++++++++++++++++++ 4 files changed, 161 insertions(+), 65 deletions(-) create mode 100644 csrc/core/include/thread_pool.hpp diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index b3b663ac0..24c82aa8d 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -96,7 +96,7 @@ std::string DocNodeToString(const lazyllm::DocNode& node) { d[py::str(group)] = std::move(ids); } const std::string children_str = py::str(d).cast(); - return "DocNode(id: " + node.get_uid() + ", group: " + node.get_group_name() + return "DocNode(id: " + node.get_uid() + ", group: " + node._group_name + ", content: " + node.get_text(lazyllm::MetadataMode::NONE) + ") parent: " + node.get_parent_uid() + ", children: " + children_str; } @@ -125,7 +125,7 @@ void exportDocNode(py::module& m) { py::arg("text") = py::none() ) .def_property_readonly("uid", &lazyllm::DocNode::get_uid) - .def_property_readonly("group", &lazyllm::DocNode::get_group_name) + .def_property_readonly("group", [](const lazyllm::DocNode& node) { return node._group_name; }) .def_property("content", [](const lazyllm::DocNode& node) { return std::string(node.get_text(lazyllm::MetadataMode::NONE)); @@ -247,7 +247,7 @@ void exportDocNode(py::module& m) { py::dict st; st["_uid"] = node.get_uid(); st["_content"] = node.get_text(lazyllm::MetadataMode::NONE); - st["_group"] = node.get_group_name(); + st["_group"] = node._group_name; st["_embedding"] = node._embedding_vecs; st["_metadata"] = node._metadata; st["_global_metadata"] = *(node._p_global_metadata); diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index ddc0cdbd7..d91e18b09 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -32,12 +32,12 @@ class DocNode { std::set _pending_embedding_keys = {}; double _relevance_score = .0; double _similarity_score = .0; + std::string _group_name; 
private: std::string_view _text_view; std::shared_ptr _p_root_text = nullptr; std::vector _root_texts = {}; - std::string _group_name; std::string _uid; mutable size_t _text_hash = 0; @@ -52,7 +52,7 @@ class DocNode { DocNode() = delete; explicit DocNode( const std::string_view& text_view, - const std::string& group_name, + const std::string& group_name = "", const std::string& uid = "", const DocNode* p_parent_node = nullptr, const Metadata& metadata = {}, @@ -74,6 +74,10 @@ class DocNode { size_t evaluate_text_hash() const { return static_cast(XXH64(_text_view.data(), _text_view.size(), 0)); } + bool is_children_group_exists(const std::string& group_name) const { + if (_children.empty()) return false; + return _children.find(group_name) != _children.end(); + } // Getter and Setter const DocNode* get_root_node() const { @@ -82,7 +86,6 @@ class DocNode { } void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } const std::string& get_uid() const { return _uid; } - const std::string& get_group_name() const { return _group_name; } const std::string_view& get_text_view() const { return _text_view; } void set_text_view(const std::string_view& text_view) { _text_view = text_view; @@ -117,7 +120,7 @@ class DocNode { void set_root_texts(const std::vector& texts) { set_root_text(JoinLines(texts)); } size_t get_text_hash() const { return _text_hash; } const DocNode* get_parent_node() const { return _p_parent_node; } - void set_parent_node(DocNode* p_parent_node) { _p_parent_node = p_parent_node; } + void set_parent_node(const DocNode* p_parent_node) { _p_parent_node = p_parent_node; } Children py_get_children() const { if (!_children.empty()) return _children; if (_p_store == nullptr) return Children(); @@ -125,6 +128,9 @@ class DocNode { return _children; } void set_children(const Children& children) { _children = children; } + void set_children_group(const std::string& group_name, const std::vector& children_group) { + _children[group_name] = 
children_group; + } std::set get_excluded_embed_metadata_keys() const { return SetUnion(get_root_node()->get_excluded_embed_metadata_keys(), _excluded_embed_metadata_keys); } diff --git a/csrc/core/include/node_transform.hpp b/csrc/core/include/node_transform.hpp index c039ea4ba..e4f1f9b9a 100644 --- a/csrc/core/include/node_transform.hpp +++ b/csrc/core/include/node_transform.hpp @@ -1,7 +1,9 @@ #pragma once #include +#include #include +#include #include #include #include @@ -11,88 +13,77 @@ #include #include "doc_node.hpp" +#include "thread_pool.hpp" namespace lazyllm { class NodeTransform { public: - using TransformKwargs = std::unordered_map; - using TransformItem = std::variant; - using TransformResult = std::vector; + std::string _name = ""; explicit NodeTransform(int worker_num = 0) : _worker_num(worker_num) {} virtual ~NodeTransform() = default; - std::vector batch_forward( - const std::vector& documents, - const std::string& node_group, - const TransformKwargs& kwargs = {}) - { - std::vector results; - for (auto* node : documents) { - if (node == nullptr) continue; - auto children = node->py_get_children(); - if (children.find(node_group) != children.end()) continue; - - auto splits = forward(node, node_group, kwargs); - children[node_group] = splits; - node->set_children(children); - results.insert(results.end(), splits.begin(), splits.end()); - } - return results; - } + virtual std::vector transform(const DocNode*) const = 0; + std::vector operator()(const DocNode& node) const { return transform(&node); } - std::vector batch_forward( - DocNode* document, - const std::string& node_group, - const TransformKwargs& kwargs = {}) - { - if (document == nullptr) return {}; - return batch_forward(std::vector{document}, node_group, kwargs); - } + std::vector batch_forward(std::vector& nodes, const std::string& node_group_name) { + std::vector whole_nodes; + if (nodes.empty()) return whole_nodes; - virtual TransformResult transform(DocNode* document, const 
TransformKwargs& kwargs) = 0; + if (_worker_num > 0) { + ThreadPool pool(static_cast(_worker_num)); + std::vector>> futures; + futures.reserve(nodes.size()); + for (auto* p_node : nodes) { + futures.emplace_back(pool.enqueue( + [this, p_node, node_group_name] { return forward(p_node, node_group_name); })); + } - NodeTransform& with_name(const std::optional& name, bool /*copy*/ = true) { - if (name.has_value()) _name = *name; - return *this; + for (auto& fut : futures) { + auto parts = fut.get(); + whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); + } + } else { + for (auto* p_node : nodes) { + auto parts = forward(p_node, node_group_name); + whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); + } + } + + return whole_nodes; } - const std::string& name() const { return _name; } int worker_num() const { return _worker_num; } -protected: - std::vector forward( - DocNode* node, - const std::string& node_group, - const TransformKwargs& kwargs) - { - TransformResult raw = transform(node, kwargs); +private: + std::vector forward(DocNode* p_node, const std::string& node_group_name) { + if (p_node->is_children_group_exists(node_group_name)) return {}; + + auto raw_nodes = transform(p_node); std::vector out; - out.reserve(raw.size()); - - for (auto& item : raw) { - if (auto* text = std::get_if(&item)) { - if (text->empty()) continue; - auto child = std::make_unique("", node_group, "", node); - child->set_root_text(std::move(*text)); - auto* ptr = child.get(); + out.reserve(raw_nodes.size()); + + for (auto& node_ : raw_nodes) { + node_.set_parent_node(p_node); + node_._group_name = node_group_name; + auto child = std::make_unique(std::move(node_)); + auto* ptr = child.get(); + { + std::lock_guard lock(_owned_nodes_mutex); _owned_nodes.emplace_back(std::move(child)); - out.push_back(ptr); - } else { - auto* child = std::get(item); - if (child == nullptr) continue; - child->set_parent_node(node); - out.push_back(child); } + out.push_back(ptr); } 
+ p_node->set_children_group(node_group_name, out); + return out; } -protected: - int _worker_num = 0; - std::string _name; std::vector> _owned_nodes; + std::mutex _owned_nodes_mutex; + int _worker_num = 0; + bool _support_rich = false; }; } // namespace lazyllm diff --git a/csrc/core/include/thread_pool.hpp b/csrc/core/include/thread_pool.hpp new file mode 100644 index 000000000..ca9c2a9c8 --- /dev/null +++ b/csrc/core/include/thread_pool.hpp @@ -0,0 +1,99 @@ +// https://github.com/progschj/ThreadPool + +#ifndef THREAD_POOL_H +#define THREAD_POOL_H +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); +private: + // need to keep track of threads so we can join them + std::vector< std::thread > workers; + // the task queue + std::queue< std::function > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) + : stop(false) +{ + for(size_t i = 0;i task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this]{ return this->stop || !this->tasks.empty(); }); + if(this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + } + ); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> +{ + using return_type = typename std::result_of::type; + + auto task = std::make_shared< std::packaged_task >( + std::bind(std::forward(f), std::forward(args)...) 
+ ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if(stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task](){ (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() +{ + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for(std::thread &worker: workers) + worker.join(); +} + +#endif From e4070f8d6974927601fd807958a3d070513105e4 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 6 Feb 2026 16:16:34 +0800 Subject: [PATCH 16/40] rename --- ...export_text_spliter_base.cpp => export_text_splitter_base.cpp} | 0 .../include/{text_spliter_base.hpp => text_splitter_base.hpp} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename csrc/binding/{export_text_spliter_base.cpp => export_text_splitter_base.cpp} (100%) rename csrc/core/include/{text_spliter_base.hpp => text_splitter_base.hpp} (100%) diff --git a/csrc/binding/export_text_spliter_base.cpp b/csrc/binding/export_text_splitter_base.cpp similarity index 100% rename from csrc/binding/export_text_spliter_base.cpp rename to csrc/binding/export_text_splitter_base.cpp diff --git a/csrc/core/include/text_spliter_base.hpp b/csrc/core/include/text_splitter_base.hpp similarity index 100% rename from csrc/core/include/text_spliter_base.hpp rename to csrc/core/include/text_splitter_base.hpp From 6017ffa4553cd6c58c43ad3c5a7dd2dfe26a2416 Mon Sep 17 00:00:00 2001 From: yzh Date: Sat, 7 Feb 2026 13:39:04 +0800 Subject: [PATCH 17/40] save --- csrc/CMakeLists.txt | 1 + csrc/binding/export_node_transform.cpp | 20 +- csrc/binding/export_text_splitter_base.cpp | 18 +- csrc/core/include/text_splitter_base.hpp | 240 +++++---------------- csrc/core/include/tokenizer.hpp | 42 ++++ csrc/core/src/text_splitter_base.cpp | 3 + 6 files changed, 111 insertions(+), 213 deletions(-) create mode 100644 
csrc/core/include/tokenizer.hpp create mode 100644 csrc/core/src/text_splitter_base.cpp diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 014e6abe2..c873210dd 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/include) target_link_libraries(lazyllm_core PUBLIC xxhash) target_link_libraries(lazyllm_core PUBLIC sentencepiece) +target_include_directories(lazyllm_core PUBLIC ${sentencepiece_SOURCE_DIR}/src) # Config lazyllm_adaptor lib which maintains callback invocations. file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS diff --git a/csrc/binding/export_node_transform.cpp b/csrc/binding/export_node_transform.cpp index 6a7788734..82d527553 100644 --- a/csrc/binding/export_node_transform.cpp +++ b/csrc/binding/export_node_transform.cpp @@ -8,18 +8,9 @@ namespace { -lazyllm::NodeTransform::TransformKwargs kwargs_to_any_map(const py::kwargs& kwargs) { - lazyllm::NodeTransform::TransformKwargs out; - for (auto item : kwargs) { - const auto key = py::cast(item.first); - out[key] = py::cast(item.second); - } - return out; -} - std::vector cast_documents(py::object documents) { std::vector docs; - if (py::isinstance(documents) && !py::isinstance(documents)) { + if (py::isinstance(documents)) { for (auto item : documents) docs.push_back(py::cast(item)); } else { docs.push_back(documents.cast()); @@ -32,15 +23,12 @@ std::vector cast_documents(py::object documents) { void exportNodeTransform(py::module& m) { py::class_(m, "NodeTransform") .def(py::init(), py::arg("num_workers") = 0) - .def( - "batch_forward", + .def("batch_forward", [](lazyllm::NodeTransform& self, py::object documents, - const std::string& node_group, - py::kwargs kwargs) { + const std::string& node_group) { auto docs = cast_documents(documents); - auto kw = kwargs_to_any_map(kwargs); - return self.batch_forward(docs, node_group, 
kw); + return self.batch_forward(docs, node_group); }, py::arg("documents"), py::arg("node_group"), diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp index 6302b1a8d..b1cb9d18a 100644 --- a/csrc/binding/export_text_splitter_base.cpp +++ b/csrc/binding/export_text_splitter_base.cpp @@ -48,7 +48,7 @@ class PyTokenizer final : public lazyllm::Tokenizer, public lazyllm::AdaptorBase } // namespace void exportTextSpliterBase(py::module& m) { - py::class_(m, "_TextSplitterBase") + py::class_(m, "TextSplitterBase") .def(py::init< std::optional, std::optional, @@ -59,12 +59,12 @@ void exportTextSpliterBase(py::module& m) { py::arg("num_workers") = py::none(), py::arg("sentencepiece_model") = py::none() ) - .def("split_text", &lazyllm::_TextSplitterBase::split_text, + .def("split_text", &lazyllm::TextSplitterBase::split_text, py::arg("text"), py::arg("metadata_size")) - .def("from_sentencepiece_model", &lazyllm::_TextSplitterBase::from_sentencepiece_model, + .def("from_sentencepiece_model", &lazyllm::TextSplitterBase::from_sentencepiece_model, py::arg("model_path"), py::return_value_policy::reference) .def("from_tokenizer", - [](lazyllm::_TextSplitterBase& self, py::object tokenizer) -> lazyllm::_TextSplitterBase& { + [](lazyllm::TextSplitterBase& self, py::object tokenizer) -> lazyllm::TextSplitterBase& { auto adaptor = std::make_shared(tokenizer); self.set_tokenizer(adaptor); return self; @@ -77,21 +77,21 @@ void exportTextSpliterBase(py::module& m) { std::unordered_map params; for (auto item : kwargs) params[py::cast(item.first)] = py::cast(item.second); - lazyllm::_TextSplitterBase::set_default(params); + lazyllm::TextSplitterBase::set_default(params); } ) .def_static("get_default", [](py::object name) { - if (name.is_none()) return py::cast(lazyllm::_TextSplitterBase::get_default()); - auto opt = lazyllm::_TextSplitterBase::get_default(name.cast()); + if (name.is_none()) return 
py::cast(lazyllm::TextSplitterBase::get_default()); + auto opt = lazyllm::TextSplitterBase::get_default(name.cast()); if (!opt.has_value()) return py::none(); return py::cast(*opt); }, py::arg("param_name") = py::none() ) - .def_static("reset_default", &lazyllm::_TextSplitterBase::reset_default); + .def_static("reset_default", &lazyllm::TextSplitterBase::reset_default); - py::class_(m, "_TokenTextSplitter") + py::class_(m, "_TokenTextSplitter") .def(py::init< std::optional, std::optional, diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 8a3a2b801..7a178b792 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -16,15 +16,14 @@ #include #include -#include - #include "doc_node.hpp" #include "node_transform.hpp" +#include "tokenizer.hpp" namespace lazyllm { -struct _Split { - std::string text; +struct ChunkUnit { + std::string_view view; bool is_sentence = false; int token_size = 0; }; @@ -55,137 +54,48 @@ inline std::vector split_text_keep_separator( return result; } -class Tokenizer { -public: - virtual ~Tokenizer() = default; - virtual std::vector encode(const std::string& text) const = 0; - virtual std::string decode(const std::vector& token_ids) const = 0; -}; - -class SentencePieceTokenizer final : public Tokenizer { -public: - SentencePieceTokenizer() = default; - explicit SentencePieceTokenizer(const std::string& model_path) { load(model_path); } - - bool load(const std::string& model_path) { - auto status = _processor.Load(model_path); - if (!status.ok()) return false; - _loaded = true; - return true; - } - - std::vector encode(const std::string& text) const override { - ensure_loaded(); - std::vector ids; - auto status = _processor.Encode(text, &ids); - if (!status.ok()) throw std::runtime_error(status.ToString()); - return ids; - } - - std::string decode(const std::vector& token_ids) const override { - ensure_loaded(); - std::string text; - auto status = 
_processor.Decode(token_ids, &text); - if (!status.ok()) throw std::runtime_error(status.ToString()); - return text; - } - -private: - void ensure_loaded() const { - if (!_loaded) throw std::runtime_error("SentencePiece model not loaded."); - } - -private: - sentencepiece::SentencePieceProcessor _processor; - bool _loaded = false; -}; - -class _TextSplitterBase : public NodeTransform { +class TextSplitterBase : public NodeTransform { public: using SplitFn = std::function(const std::string&)>; - explicit _TextSplitterBase( - std::optional chunk_size = std::nullopt, - std::optional overlap = std::nullopt, - std::optional num_workers = std::nullopt, - std::optional sentencepiece_model = std::nullopt) - : NodeTransform(get_param_value("num_workers", num_workers, 0)), - _chunk_size(get_param_value("chunk_size", chunk_size, 1024)), - _overlap(get_param_value("overlap", overlap, 200)) + explicit TextSplitterBase(unsigned chunk_size = 1024, unsigned overlap = 200, unsigned worker_num = 0, + const std::string& model_path = "") + : NodeTransform(worker_num), _chunk_size(chunk_size), _overlap(overlap) { - if (_overlap > _chunk_size) { - throw std::runtime_error( - "Got a larger chunk overlap than chunk size, should be smaller."); - } - if (_chunk_size <= 0 || _overlap < 0) - throw std::runtime_error("chunk size should > 0 and overlap should >= 0."); - - if (sentencepiece_model.has_value()) { - from_sentencepiece_model(*sentencepiece_model); - } else { - const char* env_model = std::getenv("LAZYLLM_SENTENCEPIECE_MODEL"); - if (env_model && *env_model != '\0') { - from_sentencepiece_model(std::string(env_model)); - } - } - } - - static void set_default(const std::unordered_map& params) { - std::lock_guard guard(default_params_lock()); - auto& defaults = default_params(); - for (const auto& [key, value] : params) defaults[key] = value; - } - - static std::unordered_map get_default() { - std::lock_guard guard(default_params_lock()); - return default_params(); - } + if (_overlap > 
_chunk_size) + throw std::runtime_error("'overlap' should be less than 'chunk_size'."); + if (_chunk_size == 0) + throw std::runtime_error("'chunk_size' should > 0"); - static std::optional get_default(const std::string& param_name) { - std::lock_guard guard(default_params_lock()); - auto& params = default_params(); - auto it = params.find(param_name); - if (it == params.end()) return std::nullopt; - return it->second; + _tokenizer = std::make_shared(model_path); } - static void reset_default() { - std::lock_guard guard(default_params_lock()); - default_params().clear(); - } - - _TextSplitterBase& from_sentencepiece_model(const std::string& model_path) { - auto sp = std::make_shared(); - if (!sp->load(model_path)) - throw std::runtime_error("Failed to load sentencepiece model: " + model_path); - _tokenizer = std::move(sp); - return *this; - } - - _TextSplitterBase& set_tokenizer(const std::shared_ptr& tokenizer) { - _tokenizer = tokenizer; - return *this; + std::vector transform(const DocNode* node) const override { + if (node == nullptr) return {}; + return split_text(node->get_text_view(), get_node_metadata_size(node)); } - std::vector split_text(const std::string& text, int metadata_size) { - if (text.empty()) return {""}; - const int effective_chunk_size = _chunk_size - metadata_size; + std::vector split_text(const std::string_view& view, int metadata_size) const { + if (view.empty()) return {}; + int effective_chunk_size = _chunk_size - metadata_size; if (effective_chunk_size <= 0) { throw std::runtime_error( - "Metadata length is longer than chunk size."); + "Metadata length (" + std::to_string(metadata_size) + + ") is longer than chunk size (" + std::to_string(_chunk_size) + + "). 
Consider increasing the chunk size or decreasing the size of your metadata to avoid this."); + } + else if (effective_chunk_size < 50) { + throw std::runtime_error( + "Metadata length (" + std::to_string(metadata_size) + ") is close to chunk size (" + + std::to_string(_chunk_size) + "). Resulting chunks are less than 50 tokens. " + + "Consider increasing the chunk size or decreasing the size of " + + "your metadata to avoid this."); } - auto splits = _split(text, effective_chunk_size); + auto splits = split_recursive(view, effective_chunk_size); return _merge(splits, effective_chunk_size); } - TransformResult transform(DocNode* node, const TransformKwargs& /*kwargs*/) override { - if (node == nullptr) return {}; - auto chunks = split_text(node->get_text(), _get_metadata_size(node)); - TransformResult out; - out.reserve(chunks.size()); - for (auto& chunk : chunks) out.emplace_back(std::move(chunk)); - return out; - } + virtual void set_split_fns( const std::vector& /*split_fns*/, @@ -196,29 +106,29 @@ class _TextSplitterBase : public NodeTransform { virtual void clear_split_fns() {} protected: - virtual std::vector<_Split> _split(const std::string& text, int chunk_size) { - const int token_size = _token_size(text); - if (token_size <= chunk_size) return {_Split{text, true, token_size}}; + virtual std::vector split_recursive(const std::string_view& view, int chunk_size) { + int token_amount = get_token_amount(view); + if (token_amount <= chunk_size) return {ChunkUnit{view, true, token_amount}}; - auto [text_splits, is_sentence] = _get_splits_by_fns(text); - std::vector<_Split> results; + auto [text_splits, is_sentence] = _get_splits_by_fns(view); + std::vector results; for (const auto& segment : text_splits) { - const int seg_token_size = _token_size(segment); + const int seg_token_size = get_token_amount(segment); if (seg_token_size <= chunk_size) { - results.push_back(_Split{segment, is_sentence, seg_token_size}); + results.push_back(ChunkUnit{segment, is_sentence, 
seg_token_size}); } else { - auto sub_results = _split(segment, chunk_size); + auto sub_results = split_recursive(segment, chunk_size); results.insert(results.end(), sub_results.begin(), sub_results.end()); } } return results; } - virtual std::vector _merge(std::vector<_Split> splits, int chunk_size) { + virtual std::vector _merge(std::vector splits, int chunk_size) { if (splits.empty()) return {}; if (splits.size() == 1) return {splits.front().text}; - _Split end_split = splits.back(); + ChunkUnit end_split = splits.back(); if (end_split.token_size == chunk_size && _overlap > 0) { splits.pop_back(); auto text_tokens = encode(end_split.text); @@ -227,19 +137,19 @@ class _TextSplitterBase : public NodeTransform { std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); std::string p_text = decode(p_tokens); std::string n_text = decode(n_tokens); - splits.push_back(_Split{p_text, end_split.is_sentence, _token_size(p_text)}); - splits.push_back(_Split{n_text, end_split.is_sentence, _token_size(n_text)}); + splits.push_back(ChunkUnit{p_text, end_split.is_sentence, get_token_amount(p_text)}); + splits.push_back(ChunkUnit{n_text, end_split.is_sentence, get_token_amount(n_text)}); end_split = splits.back(); } std::vector result; for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { - _Split start_split = splits[static_cast(idx)]; + ChunkUnit start_split = splits[static_cast(idx)]; if (start_split.token_size <= _overlap && end_split.token_size <= chunk_size - _overlap) { const bool is_sentence = start_split.is_sentence && end_split.is_sentence; const int token_size = start_split.token_size + end_split.token_size; - end_split = _Split{start_split.text + end_split.text, is_sentence, token_size}; + end_split = ChunkUnit{start_split.text + end_split.text, is_sentence, token_size}; continue; } @@ -254,7 +164,7 @@ class _TextSplitterBase : public NodeTransform { std::vector overlap_tokens( start_tokens.end() - overlap_len, start_tokens.end()); std::string 
overlap_text = decode(overlap_tokens); - end_split = _Split{overlap_text + end_split.text, end_split.is_sentence, + end_split = ChunkUnit{overlap_text + end_split.text, end_split.is_sentence, end_split.token_size + overlap_len}; } @@ -284,24 +194,14 @@ class _TextSplitterBase : public NodeTransform { return {split_to_chars(text), false}; } - int _get_metadata_size(const DocNode* node) const { + int get_node_metadata_size(const DocNode* node) const { return std::max( - _token_size(node->get_metadata_string(MetadataMode::EMBED)), - _token_size(node->get_metadata_string(MetadataMode::LLM))); - } - - int _token_size(const std::string& text) const { - return static_cast(encode(text).size()); + get_token_amount(node->get_metadata_string(MetadataMode::EMBED)), + get_token_amount(node->get_metadata_string(MetadataMode::LLM))); } - std::vector encode(const std::string& text) const { - if (!_tokenizer) throw std::runtime_error("Tokenizer not initialized."); - return _tokenizer->encode(text); - } - - std::string decode(const std::vector& token_ids) const { - if (!_tokenizer) throw std::runtime_error("Tokenizer not initialized."); - return _tokenizer->decode(token_ids); + int get_token_amount(const std::string_view& view) const { + return static_cast(_tokenizer->encode(view).size()); } static std::vector regex_find_all( @@ -348,45 +248,9 @@ class _TextSplitterBase : public NodeTransform { } protected: - int _chunk_size = 1024; - int _overlap = 200; + int _chunk_size; + int _overlap; std::shared_ptr _tokenizer; }; -class _TokenTextSplitter : public _TextSplitterBase { -public: - explicit _TokenTextSplitter( - std::optional chunk_size = std::nullopt, - std::optional overlap = std::nullopt, - std::optional num_workers = std::nullopt, - std::optional sentencepiece_model = std::nullopt) - : _TextSplitterBase(chunk_size, overlap, num_workers, sentencepiece_model) {} - -protected: - std::vector<_Split> _split(const std::string& text, int chunk_size) override { - const int token_size 
= _token_size(text); - if (token_size <= chunk_size) return {_Split{text, true, token_size}}; - - std::vector<_Split> results; - auto tokens = encode(text); - size_t start_idx = 0; - size_t end_idx = std::min(start_idx + static_cast(chunk_size), tokens.size()); - while (start_idx < tokens.size()) { - std::vector chunk_tokens(tokens.begin() + start_idx, tokens.begin() + end_idx); - results.push_back(_Split{decode(chunk_tokens), true, static_cast(chunk_tokens.size())}); - if (end_idx >= tokens.size()) break; - start_idx = std::min(start_idx + static_cast(chunk_size - _overlap), tokens.size()); - end_idx = std::min(start_idx + static_cast(chunk_size), tokens.size()); - } - return results; - } - - std::vector _merge(std::vector<_Split> splits, int /*chunk_size*/) override { - std::vector out; - out.reserve(splits.size()); - for (auto& split : splits) out.emplace_back(std::move(split.text)); - return out; - } -}; - } // namespace lazyllm diff --git a/csrc/core/include/tokenizer.hpp b/csrc/core/include/tokenizer.hpp new file mode 100644 index 000000000..93bf36fa0 --- /dev/null +++ b/csrc/core/include/tokenizer.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include + +#include + +class Tokenizer { +public: + virtual ~Tokenizer() = default; + virtual std::vector encode(const std::string_view& view) const = 0; + virtual std::string decode(const std::vector& token_ids) const = 0; +}; + +class SentencePieceTokenizer final : public Tokenizer { +public: + SentencePieceTokenizer() = delete; + explicit SentencePieceTokenizer(const std::string& model_path) { + auto status = _processor.Load(model_path); + if (!status.ok()) + throw std::runtime_error("Failed to load sentencepiece model: " + model_path);; + } + + std::vector encode(const std::string_view& view) const override { + std::vector ids; + auto status = _processor.Encode(view, &ids); + if (!status.ok()) throw std::runtime_error(status.ToString()); + return ids; + } + + std::string decode(const 
std::vector& token_ids) const override { + std::string text; + auto status = _processor.Decode(token_ids, &text); + if (!status.ok()) throw std::runtime_error(status.ToString()); + return text; + } + +private: + sentencepiece::SentencePieceProcessor _processor; +}; diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp new file mode 100644 index 000000000..a26c7b417 --- /dev/null +++ b/csrc/core/src/text_splitter_base.cpp @@ -0,0 +1,3 @@ +#include "text_splitter_base.hpp" + + From 615b7b0e630493dc0c68bbebf9d0da86cca5908f Mon Sep 17 00:00:00 2001 From: yzh Date: Sat, 7 Feb 2026 13:44:22 +0800 Subject: [PATCH 18/40] Module --- csrc/cmake/third_party.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cmake/third_party.cmake b/csrc/cmake/third_party.cmake index c5dd0cfe9..3aa8171a2 100644 --- a/csrc/cmake/third_party.cmake +++ b/csrc/cmake/third_party.cmake @@ -1,6 +1,6 @@ include(FetchContent) -find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED) find_package(pybind11 CONFIG REQUIRED) find_package(xxHash QUIET) From 0b193c839e5ddc604f2fc241aa36e56e3757df3e Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 10 Feb 2026 14:27:44 +0800 Subject: [PATCH 19/40] map_params --- csrc/CMakeLists.txt | 3 +- csrc/binding/export_text_splitter_base.cpp | 14 +- csrc/cmake/third_party.cmake | 24 ++- csrc/core/include/map_params.hpp | 65 +++++++ csrc/core/include/text_splitter_base.hpp | 200 ++++++++++----------- csrc/core/include/tokenizer.hpp | 55 ++++-- 6 files changed, 210 insertions(+), 151 deletions(-) create mode 100644 csrc/core/include/map_params.hpp diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index c873210dd..11938c9b7 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,8 +18,7 @@ file(GLOB_RECURSE LAZYLLM_CORE_SOURCES CONFIGURE_DEPENDS add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) 
target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/include) target_link_libraries(lazyllm_core PUBLIC xxhash) -target_link_libraries(lazyllm_core PUBLIC sentencepiece) -target_include_directories(lazyllm_core PUBLIC ${sentencepiece_SOURCE_DIR}/src) +target_link_libraries(lazyllm_core PUBLIC tiktoken) # Config lazyllm_adaptor lib which maintains callback invocations. file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp index b1cb9d18a..5010e3bbb 100644 --- a/csrc/binding/export_text_splitter_base.cpp +++ b/csrc/binding/export_text_splitter_base.cpp @@ -1,7 +1,7 @@ #include "lazyllm.hpp" #include "adaptor_base_wrapper.hpp" -#include "text_spliter_base.hpp" +#include "text_splitter_base.hpp" #include #include @@ -12,8 +12,8 @@ class PyTokenizer final : public lazyllm::Tokenizer, public lazyllm::AdaptorBase public: explicit PyTokenizer(const py::object& obj) : AdaptorBaseWrapper(obj) {} - std::vector encode(const std::string& text) const override { - auto result = call("encode", {{"text", text}}); + std::vector encode(const std::string_view& text) const override { + auto result = call("encode", {{"text", std::string(text)}}); return std::any_cast>(result); } @@ -57,12 +57,12 @@ void exportTextSpliterBase(py::module& m) { py::arg("chunk_size") = py::none(), py::arg("overlap") = py::none(), py::arg("num_workers") = py::none(), - py::arg("sentencepiece_model") = py::none() + py::arg("encoding_name") = py::none() ) .def("split_text", &lazyllm::TextSplitterBase::split_text, py::arg("text"), py::arg("metadata_size")) - .def("from_sentencepiece_model", &lazyllm::TextSplitterBase::from_sentencepiece_model, - py::arg("model_path"), py::return_value_policy::reference) + .def("from_tiktoken_encoding", &lazyllm::TextSplitterBase::from_tiktoken_encoding, + py::arg("encoding_name"), py::return_value_policy::reference) .def("from_tokenizer", 
[](lazyllm::TextSplitterBase& self, py::object tokenizer) -> lazyllm::TextSplitterBase& { auto adaptor = std::make_shared(tokenizer); @@ -100,7 +100,7 @@ void exportTextSpliterBase(py::module& m) { py::arg("chunk_size") = py::none(), py::arg("overlap") = py::none(), py::arg("num_workers") = py::none(), - py::arg("sentencepiece_model") = py::none() + py::arg("encoding_name") = py::none() ); m.def("split_text_keep_separator", &lazyllm::split_text_keep_separator, diff --git a/csrc/cmake/third_party.cmake b/csrc/cmake/third_party.cmake index 3aa8171a2..3246a8645 100644 --- a/csrc/cmake/third_party.cmake +++ b/csrc/cmake/third_party.cmake @@ -1,6 +1,6 @@ include(FetchContent) -find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED) +find_package(Python3 COMPONENTS Interpreter Development Development.Module REQUIRED) find_package(pybind11 CONFIG REQUIRED) find_package(xxHash QUIET) @@ -14,19 +14,15 @@ if (NOT TARGET xxhash) add_subdirectory(${xxhash_SOURCE_DIR}/cmake_unofficial ${xxhash_BINARY_DIR}) endif() -find_package(sentencepiece QUIET) -if (NOT TARGET sentencepiece AND NOT TARGET sentencepiece::sentencepiece AND NOT TARGET sentencepiece-static) +find_package(cpp_tiktoken QUIET) +if (NOT TARGET cpp_tiktoken) + # We only need cpp_tiktoken for in-tree usage; avoid exporting/installing it. 
+ set(CPP_TIKTOKEN_INSTALL OFF CACHE BOOL "" FORCE) + set(CPP_TIKTOKEN_TESTING OFF CACHE BOOL "" FORCE) FetchContent_Declare( - sentencepiece - GIT_REPOSITORY https://github.com/google/sentencepiece.git - GIT_TAG v0.2.0 + cpp_tiktoken + GIT_REPOSITORY https://github.com/gh-markt/cpp-tiktoken.git + GIT_TAG master ) - FetchContent_MakeAvailable(sentencepiece) -endif() -if (TARGET sentencepiece::sentencepiece) - add_library(sentencepiece ALIAS sentencepiece::sentencepiece) -elseif (TARGET sentencepiece) - add_library(sentencepiece ALIAS sentencepiece) -elseif (TARGET sentencepiece-static) - add_library(sentencepiece ALIAS sentencepiece-static) + FetchContent_MakeAvailable(cpp_tiktoken) endif() diff --git a/csrc/core/include/map_params.hpp b/csrc/core/include/map_params.hpp new file mode 100644 index 000000000..ed0c7bbc6 --- /dev/null +++ b/csrc/core/include/map_params.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace lazyllm { + +class MapParams { +public: + using MapType = std::unordered_map; + + template + T get_param_value( + const std::string_view& param_name, + const std::optional& value, + const T& default_value) const + { + if (value.has_value()) return *value; + std::lock_guard guard(_lock); + auto it = _params.find(param_name); + if (it != _params.end()) return std::any_cast(it->second); + return default_value; + } + + template + void set_default(const std::string_view& param_name, T value) { + std::lock_guard guard(_lock); + _params[param_name] = std::any(value); + } + + void set_default(const MapType& updates) { + std::lock_guard guard(_lock); + for (const auto& entry : updates) { + _params[entry.first] = entry.second; + } + } + + MapType get_default() const { + std::lock_guard guard(_lock); + return _params; + } + + template + std::optional get_default(const std::string& param_name) const { + std::lock_guard guard(_lock); + auto it = _params.find(param_name); + if (it == _params.end()) return 
std::nullopt; + return std::any_cast(it->second); + } + + void reset_default() { + std::lock_guard guard(_lock); + _params.clear(); + } + +private: + mutable std::mutex _lock; + MapType _params; +}; + +} // namespace lazyllm diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 7a178b792..80c11878a 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -6,68 +6,46 @@ #include #include #include -#include #include #include #include #include -#include #include #include #include #include "doc_node.hpp" +#include "map_params.hpp" #include "node_transform.hpp" #include "tokenizer.hpp" namespace lazyllm { -struct ChunkUnit { +struct SplitUnit { std::string_view view; bool is_sentence = false; int token_size = 0; }; -inline std::vector split_text_keep_separator( - const std::string& text, - const std::string& separator) -{ - if (separator.empty()) return text.empty() ? std::vector() : std::vector{text}; - if (text.find(separator) == std::string::npos) return {text}; - - std::vector result; - size_t start = 0; - const size_t sep_len = separator.size(); - while (start < text.size()) { - const size_t idx = text.find(separator, start); - if (idx == std::string::npos) { - result.emplace_back(text.substr(start)); - break; - } - if (idx == 0) { - start = sep_len; - continue; - } - result.emplace_back(text.substr(start, idx - start + sep_len)); - start = idx + sep_len; - } - return result; -} - class TextSplitterBase : public NodeTransform { public: using SplitFn = std::function(const std::string&)>; - explicit TextSplitterBase(unsigned chunk_size = 1024, unsigned overlap = 200, unsigned worker_num = 0, - const std::string& model_path = "") - : NodeTransform(worker_num), _chunk_size(chunk_size), _overlap(overlap) + explicit TextSplitterBase( + std::optional chunk_size, + std::optional overlap, + std::optional worker_num, + const std::string& encoding_name = "gpt2") + : 
NodeTransform(_default_params.get_param_value("worker_num", worker_num, 0)), + _chunk_size(_default_params.get_param_value("chunk_size", chunk_size, 1024)), + _overlap(_default_params.get_param_value("overlap", overlap, 200)) { if (_overlap > _chunk_size) throw std::runtime_error("'overlap' should be less than 'chunk_size'."); if (_chunk_size == 0) throw std::runtime_error("'chunk_size' should > 0"); - _tokenizer = std::make_shared(model_path); + _tokenizer = std::make_shared(encoding_name); } std::vector transform(const DocNode* node) const override { @@ -95,8 +73,6 @@ class TextSplitterBase : public NodeTransform { return _merge(splits, effective_chunk_size); } - - virtual void set_split_fns( const std::vector& /*split_fns*/, const std::optional>& /*sub_split_fns*/ = std::nullopt) {} @@ -106,16 +82,16 @@ class TextSplitterBase : public NodeTransform { virtual void clear_split_fns() {} protected: - virtual std::vector split_recursive(const std::string_view& view, int chunk_size) { - int token_amount = get_token_amount(view); - if (token_amount <= chunk_size) return {ChunkUnit{view, true, token_amount}}; + virtual std::vector split_recursive(const std::string_view& view, int chunk_size) const { + int token_size = get_token_size(view); + if (token_size <= chunk_size) return {SplitUnit{view, true, token_size}}; auto [text_splits, is_sentence] = _get_splits_by_fns(view); - std::vector results; + std::vector results; for (const auto& segment : text_splits) { - const int seg_token_size = get_token_amount(segment); + const int seg_token_size = get_token_size(segment); if (seg_token_size <= chunk_size) { - results.push_back(ChunkUnit{segment, is_sentence, seg_token_size}); + results.push_back(SplitUnit{segment, is_sentence, seg_token_size}); } else { auto sub_results = split_recursive(segment, chunk_size); results.insert(results.end(), sub_results.begin(), sub_results.end()); @@ -124,60 +100,34 @@ class TextSplitterBase : public NodeTransform { return results; } - virtual 
std::vector _merge(std::vector splits, int chunk_size) { - if (splits.empty()) return {}; - if (splits.size() == 1) return {splits.front().text}; - - ChunkUnit end_split = splits.back(); - if (end_split.token_size == chunk_size && _overlap > 0) { - splits.pop_back(); - auto text_tokens = encode(end_split.text); - const size_t half = text_tokens.size() / 2; - std::vector p_tokens(text_tokens.begin(), text_tokens.begin() + half); - std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); - std::string p_text = decode(p_tokens); - std::string n_text = decode(n_tokens); - splits.push_back(ChunkUnit{p_text, end_split.is_sentence, get_token_amount(p_text)}); - splits.push_back(ChunkUnit{n_text, end_split.is_sentence, get_token_amount(n_text)}); - end_split = splits.back(); - } + std::vector split_text_keep_separator( + const std::string& text, + const std::string& separator) + { + if (separator.empty()) return text.empty() ? std::vector() : std::vector{text}; + if (text.find(separator) == std::string::npos) return {text}; std::vector result; - for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { - ChunkUnit start_split = splits[static_cast(idx)]; - if (start_split.token_size <= _overlap && - end_split.token_size <= chunk_size - _overlap) { - const bool is_sentence = start_split.is_sentence && end_split.is_sentence; - const int token_size = start_split.token_size + end_split.token_size; - end_split = ChunkUnit{start_split.text + end_split.text, is_sentence, token_size}; - continue; - } - - if (end_split.token_size > chunk_size) { - throw std::runtime_error("split token size is greater than chunk size."); + size_t start = 0; + const size_t sep_len = separator.size(); + while (start < text.size()) { + const size_t idx = text.find(separator, start); + if (idx == std::string::npos) { + result.emplace_back(text.substr(start)); + break; } - - const int remaining_space = chunk_size - end_split.token_size; - const int overlap_len = std::min({_overlap, 
remaining_space, start_split.token_size}); - if (overlap_len > 0) { - auto start_tokens = encode(start_split.text); - std::vector overlap_tokens( - start_tokens.end() - overlap_len, start_tokens.end()); - std::string overlap_text = decode(overlap_tokens); - end_split = ChunkUnit{overlap_text + end_split.text, end_split.is_sentence, - end_split.token_size + overlap_len}; + if (idx == 0) { + start = sep_len; + continue; } - - result.insert(result.begin(), end_split.text); - end_split = start_split; + result.emplace_back(text.substr(start, idx - start + sep_len)); + start = idx + sep_len; } - - result.insert(result.begin(), end_split.text); return result; } - virtual std::pair, bool> _get_splits_by_fns( - const std::string& text) const + + virtual std::pair, bool> _get_splits_by_fns(const std::string& text) const { auto splits = split_text_keep_separator(text, "\n\n\n"); if (splits.size() > 1) return {splits, true}; @@ -196,11 +146,11 @@ class TextSplitterBase : public NodeTransform { int get_node_metadata_size(const DocNode* node) const { return std::max( - get_token_amount(node->get_metadata_string(MetadataMode::EMBED)), - get_token_amount(node->get_metadata_string(MetadataMode::LLM))); + get_token_size(node->get_metadata_string(MetadataMode::EMBED)), + get_token_size(node->get_metadata_string(MetadataMode::LLM))); } - int get_token_amount(const std::string_view& view) const { + int get_token_size(const std::string_view& view) const { return static_cast(_tokenizer->encode(view).size()); } @@ -224,29 +174,61 @@ class TextSplitterBase : public NodeTransform { return out; } - static std::unordered_map& default_params() { - static std::unordered_map params; - return params; - } + std::vector _merge(std::vector splits, int chunk_size) { + if (splits.empty()) return {}; + if (splits.size() == 1) return {splits.front().text}; - static std::recursive_mutex& default_params_lock() { - static std::recursive_mutex lock; - return lock; - } + SplitUnit end_split = splits.back(); 
+ if (end_split.token_size == chunk_size && _overlap > 0) { + splits.pop_back(); + auto text_tokens = encode(end_split.text); + const size_t half = text_tokens.size() / 2; + std::vector p_tokens(text_tokens.begin(), text_tokens.begin() + half); + std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); + std::string p_text = decode(p_tokens); + std::string n_text = decode(n_tokens); + splits.push_back(SplitUnit{p_text, end_split.is_sentence, get_token_size(p_text)}); + splits.push_back(SplitUnit{n_text, end_split.is_sentence, get_token_size(n_text)}); + end_split = splits.back(); + } - static int get_param_value( - const std::string& param_name, - const std::optional& value, - int default_value) - { - if (value.has_value()) return *value; - std::lock_guard guard(default_params_lock()); - auto& params = default_params(); - auto it = params.find(param_name); - if (it != params.end()) return it->second; - return default_value; + std::vector result; + for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { + SplitUnit start_split = splits[static_cast(idx)]; + if (start_split.token_size <= _overlap && + end_split.token_size <= chunk_size - _overlap) { + const bool is_sentence = start_split.is_sentence && end_split.is_sentence; + const int token_size = start_split.token_size + end_split.token_size; + end_split = SplitUnit{start_split.text + end_split.text, is_sentence, token_size}; + continue; + } + + if (end_split.token_size > chunk_size) { + throw std::runtime_error("split token size is greater than chunk size."); + } + + const int remaining_space = chunk_size - end_split.token_size; + const int overlap_len = std::min({_overlap, remaining_space, start_split.token_size}); + if (overlap_len > 0) { + auto start_tokens = encode(start_split.text); + std::vector overlap_tokens( + start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = decode(overlap_tokens); + end_split = SplitUnit{overlap_text + end_split.text, 
end_split.is_sentence, + end_split.token_size + overlap_len}; + } + + result.insert(result.begin(), end_split.text); + end_split = start_split; + } + + result.insert(result.begin(), end_split.text); + return result; } +private: + static MapParams _default_params; + protected: int _chunk_size; int _overlap; diff --git a/csrc/core/include/tokenizer.hpp b/csrc/core/include/tokenizer.hpp index 93bf36fa0..dd667b178 100644 --- a/csrc/core/include/tokenizer.hpp +++ b/csrc/core/include/tokenizer.hpp @@ -1,11 +1,15 @@ #pragma once -#include -#include +#include +#include #include #include +#include +#include +#include -#include +#include +#include class Tokenizer { public: @@ -14,29 +18,42 @@ class Tokenizer { virtual std::string decode(const std::vector& token_ids) const = 0; }; -class SentencePieceTokenizer final : public Tokenizer { +class TiktokenTokenizer final : public Tokenizer { public: - SentencePieceTokenizer() = delete; - explicit SentencePieceTokenizer(const std::string& model_path) { - auto status = _processor.Load(model_path); - if (!status.ok()) - throw std::runtime_error("Failed to load sentencepiece model: " + model_path);; - } + TiktokenTokenizer() = delete; + explicit TiktokenTokenizer(LanguageModel model) + : _encoding(GptEncoding::get_encoding(model)) {} + + explicit TiktokenTokenizer(std::string_view encoding_name) + : TiktokenTokenizer(parse_tiktoken_model(encoding_name)) {} + + ~TiktokenTokenizer() override = default; std::vector encode(const std::string_view& view) const override { - std::vector ids; - auto status = _processor.Encode(view, &ids); - if (!status.ok()) throw std::runtime_error(status.ToString()); - return ids; + return _encoding->encode(std::string(view)); // TODO refactor to string_view } std::string decode(const std::vector& token_ids) const override { - std::string text; - auto status = _processor.Decode(token_ids, &text); - if (!status.ok()) throw std::runtime_error(status.ToString()); - return text; + return 
_encoding->decode(token_ids); + } + +private: + static LanguageModel parse_tiktoken_model(std::string_view name) { + if (name.empty()) return LanguageModel::R50K_BASE; + + if (name == "gpt2" || name == "r50k_base" || name == "r50k") return LanguageModel::R50K_BASE; + if (name == "p50k_base" || name == "p50k") return LanguageModel::P50K_BASE; + if (name == "p50k_edit") return LanguageModel::P50K_EDIT; + if (name == "cl100k_base" || name == "cl100k") return LanguageModel::CL100K_BASE; + if (name == "o200k_base" || name == "o200k") return LanguageModel::O200K_BASE; + if (name == "qwen_base" || name == "qwen") return LanguageModel::QWEN_BASE; + + throw std::runtime_error( + "Unknown tiktoken encoding/model name: " + std::string(name) + + ". Expected one of: gpt2, r50k_base, p50k_base, p50k_edit, cl100k_base, o200k_base, qwen_base." + + "(Case sensitive)"); } private: - sentencepiece::SentencePieceProcessor _processor; + std::shared_ptr _encoding; }; From 0d88ea65443559899d91e15aa153c76d86036908 Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 10 Feb 2026 18:21:36 +0800 Subject: [PATCH 20/40] save --- csrc/core/include/text_splitter_base.hpp | 165 +++-------------------- csrc/core/src/text_splitter_base.cpp | 128 ++++++++++++++++++ 2 files changed, 146 insertions(+), 147 deletions(-) diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 80c11878a..8da0e3574 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -40,10 +40,8 @@ class TextSplitterBase : public NodeTransform { _chunk_size(_default_params.get_param_value("chunk_size", chunk_size, 1024)), _overlap(_default_params.get_param_value("overlap", overlap, 200)) { - if (_overlap > _chunk_size) - throw std::runtime_error("'overlap' should be less than 'chunk_size'."); - if (_chunk_size == 0) - throw std::runtime_error("'chunk_size' should > 0"); + if (_overlap > _chunk_size) throw std::runtime_error("'overlap' should be 
less than 'chunk_size'."); + if (_chunk_size == 0) throw std::runtime_error("'chunk_size' should > 0"); _tokenizer = std::make_shared(encoding_name); } @@ -53,96 +51,23 @@ class TextSplitterBase : public NodeTransform { return split_text(node->get_text_view(), get_node_metadata_size(node)); } - std::vector split_text(const std::string_view& view, int metadata_size) const { - if (view.empty()) return {}; - int effective_chunk_size = _chunk_size - metadata_size; - if (effective_chunk_size <= 0) { - throw std::runtime_error( - "Metadata length (" + std::to_string(metadata_size) + - ") is longer than chunk size (" + std::to_string(_chunk_size) + - "). Consider increasing the chunk size or decreasing the size of your metadata to avoid this."); - } - else if (effective_chunk_size < 50) { - throw std::runtime_error( - "Metadata length (" + std::to_string(metadata_size) + ") is close to chunk size (" + - std::to_string(_chunk_size) + "). Resulting chunks are less than 50 tokens. " + - "Consider increasing the chunk size or decreasing the size of " + - "your metadata to avoid this."); - } - auto splits = split_recursive(view, effective_chunk_size); - return _merge(splits, effective_chunk_size); - } - - virtual void set_split_fns( - const std::vector& /*split_fns*/, - const std::optional>& /*sub_split_fns*/ = std::nullopt) {} + std::vector split_text(const std::string_view& view, int metadata_size) const; - virtual void add_split_fn(const SplitFn& /*split_fn*/, const std::optional& /*index*/ = std::nullopt) {} - - virtual void clear_split_fns() {} + virtual void set_split_functions( + const std::vector&, + const std::optional>& = std::nullopt) {} + virtual void add_split_function(const SplitFn&, const std::optional& = std::nullopt) {} + virtual void clear_split_functions() {} protected: - virtual std::vector split_recursive(const std::string_view& view, int chunk_size) const { - int token_size = get_token_size(view); - if (token_size <= chunk_size) return {SplitUnit{view, 
true, token_size}}; - - auto [text_splits, is_sentence] = _get_splits_by_fns(view); - std::vector results; - for (const auto& segment : text_splits) { - const int seg_token_size = get_token_size(segment); - if (seg_token_size <= chunk_size) { - results.push_back(SplitUnit{segment, is_sentence, seg_token_size}); - } else { - auto sub_results = split_recursive(segment, chunk_size); - results.insert(results.end(), sub_results.begin(), sub_results.end()); - } - } - return results; - } - - std::vector split_text_keep_separator( - const std::string& text, - const std::string& separator) - { - if (separator.empty()) return text.empty() ? std::vector() : std::vector{text}; - if (text.find(separator) == std::string::npos) return {text}; - - std::vector result; - size_t start = 0; - const size_t sep_len = separator.size(); - while (start < text.size()) { - const size_t idx = text.find(separator, start); - if (idx == std::string::npos) { - result.emplace_back(text.substr(start)); - break; - } - if (idx == 0) { - start = sep_len; - continue; - } - result.emplace_back(text.substr(start, idx - start + sep_len)); - start = idx + sep_len; - } - return result; - } - - - virtual std::pair, bool> _get_splits_by_fns(const std::string& text) const - { - auto splits = split_text_keep_separator(text, "\n\n\n"); - if (splits.size() > 1) return {splits, true}; + virtual std::vector split_recursive(const std::string_view& view, const int chunk_size) const; + virtual std::vector _merge(std::vector splits, int chunk_size); - splits = regex_find_all(text, R"([^.!?。?!]+[.!?。?!]?)"); - if (splits.size() > 1) return {splits, true}; - - splits = regex_find_all(text, R"([^,.;。?!]+[,.;。?!]?)"); - if (splits.size() > 1) return {splits, false}; - - splits = split_text_keep_separator(text, " "); - if (splits.size() > 1) return {splits, false}; - - return {split_to_chars(text), false}; - } +private: + std::tuple, bool> split_by_functions(const std::string& text) const; + std::vector 
split_text_while_keeping_separator( + const std::string_view& text, + const std::string_view& separator) const; int get_node_metadata_size(const DocNode* node) const { return std::max( @@ -154,9 +79,7 @@ class TextSplitterBase : public NodeTransform { return static_cast(_tokenizer->encode(view).size()); } - static std::vector regex_find_all( - const std::string& text, const std::string& pattern) - { + static std::vector regex_find_all(const std::string_view& text, const std::string_view& pattern) { std::regex re(pattern); std::vector out; for (auto it = std::sregex_iterator(text.begin(), text.end(), re); @@ -167,65 +90,13 @@ class TextSplitterBase : public NodeTransform { return out; } - static std::vector split_to_chars(const std::string& text) { - std::vector out; + static std::vector split_to_chars(const std::string_view& text) { + std::vector out; out.reserve(text.size()); for (char c : text) out.emplace_back(1, c); return out; } - std::vector _merge(std::vector splits, int chunk_size) { - if (splits.empty()) return {}; - if (splits.size() == 1) return {splits.front().text}; - - SplitUnit end_split = splits.back(); - if (end_split.token_size == chunk_size && _overlap > 0) { - splits.pop_back(); - auto text_tokens = encode(end_split.text); - const size_t half = text_tokens.size() / 2; - std::vector p_tokens(text_tokens.begin(), text_tokens.begin() + half); - std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); - std::string p_text = decode(p_tokens); - std::string n_text = decode(n_tokens); - splits.push_back(SplitUnit{p_text, end_split.is_sentence, get_token_size(p_text)}); - splits.push_back(SplitUnit{n_text, end_split.is_sentence, get_token_size(n_text)}); - end_split = splits.back(); - } - - std::vector result; - for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { - SplitUnit start_split = splits[static_cast(idx)]; - if (start_split.token_size <= _overlap && - end_split.token_size <= chunk_size - _overlap) { - const bool 
is_sentence = start_split.is_sentence && end_split.is_sentence; - const int token_size = start_split.token_size + end_split.token_size; - end_split = SplitUnit{start_split.text + end_split.text, is_sentence, token_size}; - continue; - } - - if (end_split.token_size > chunk_size) { - throw std::runtime_error("split token size is greater than chunk size."); - } - - const int remaining_space = chunk_size - end_split.token_size; - const int overlap_len = std::min({_overlap, remaining_space, start_split.token_size}); - if (overlap_len > 0) { - auto start_tokens = encode(start_split.text); - std::vector overlap_tokens( - start_tokens.end() - overlap_len, start_tokens.end()); - std::string overlap_text = decode(overlap_tokens); - end_split = SplitUnit{overlap_text + end_split.text, end_split.is_sentence, - end_split.token_size + overlap_len}; - } - - result.insert(result.begin(), end_split.text); - end_split = start_split; - } - - result.insert(result.begin(), end_split.text); - return result; - } - private: static MapParams _default_params; diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index a26c7b417..877949d08 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -1,3 +1,131 @@ #include "text_splitter_base.hpp" +namespace lazyllm { +std::vector TextSplitterBase::split_text(const std::string_view& view, int metadata_size) const { + if (view.empty()) return {}; + int effective_chunk_size = _chunk_size - metadata_size; + if (effective_chunk_size <= 0) { + throw std::runtime_error( + "Metadata length (" + std::to_string(metadata_size) + + ") is longer than chunk size (" + std::to_string(_chunk_size) + + "). Consider increasing the chunk size or decreasing the size of your metadata to avoid this."); + } + else if (effective_chunk_size < 50) { + throw std::runtime_error( + "Metadata length (" + std::to_string(metadata_size) + ") is close to chunk size (" + + std::to_string(_chunk_size) + "). 
Resulting chunks are less than 50 tokens. " + + "Consider increasing the chunk size or decreasing the size of " + + "your metadata to avoid this."); + } + auto splits = split_recursive(view, effective_chunk_size); + return _merge(splits, effective_chunk_size); +} + +std::vector TextSplitterBase::split_recursive( + const std::string_view& view, const int chunk_size) const +{ + int token_size = get_token_size(view); + if (token_size <= chunk_size) return {SplitUnit{view, true, token_size}}; + + auto [views, is_sentence] = split_by_functions(view); + std::vector splits; + for (const auto& view : views) { + const int seg_token_size = get_token_size(view); + if (seg_token_size <= chunk_size) { + splits.emplace_back(view, is_sentence, seg_token_size); + } else { + auto new_splits = split_recursive(view, chunk_size); + splits.insert(splits.end(), new_splits.begin(), new_splits.end()); + } + } + return splits; +} + +std::tuple, bool> TextSplitterBase::split_by_functions(const std::string& text) const +{ + auto views = split_text_while_keeping_separator(text, "\n\n\n"); + if (views.size() > 1) return {views, true}; + + views = regex_find_all(text, R"([^,.;。?!]+[,.;。?!]?)"); + if (views.size() > 1) return {views, false}; + + views = split_text_while_keeping_separator(text, " "); + if (views.size() > 1) return {views, false}; + + return {split_to_chars(text), false}; +} + +std::vector TextSplitterBase::split_text_while_keeping_separator( + const std::string_view& text, + const std::string_view& separator) const +{ + if (text.empty()) return {}; + else if (separator.empty()) return {text}; + + std::vector result; + size_t start = 0; + const size_t sep_len = separator.size(); + while (start <= text.size()) { + const size_t idx = text.find(separator, start); + const size_t end = (idx == std::string_view::npos) ? 
text.size() : idx; + if (end > start) result.emplace_back(text.substr(start, end - start)); + if (idx == std::string_view::npos) break; + start = idx + sep_len; + } + return result; +} + +std::vector TextSplitterBase::_merge(std::vector splits, int chunk_size) { + if (splits.empty()) return {}; + if (splits.size() == 1) return {splits.front().text}; + + SplitUnit end_split = splits.back(); + if (end_split.token_size == chunk_size && _overlap > 0) { + splits.pop_back(); + auto text_tokens = encode(end_split.text); + const size_t half = text_tokens.size() / 2; + std::vector p_tokens(text_tokens.begin(), text_tokens.begin() + half); + std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); + std::string p_text = decode(p_tokens); + std::string n_text = decode(n_tokens); + splits.push_back(SplitUnit{p_text, end_split.is_sentence, get_token_size(p_text)}); + splits.push_back(SplitUnit{n_text, end_split.is_sentence, get_token_size(n_text)}); + end_split = splits.back(); + } + + std::vector result; + for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { + SplitUnit start_split = splits[static_cast(idx)]; + if (start_split.token_size <= _overlap && + end_split.token_size <= chunk_size - _overlap) { + const bool is_sentence = start_split.is_sentence && end_split.is_sentence; + const int token_size = start_split.token_size + end_split.token_size; + end_split = SplitUnit{start_split.text + end_split.text, is_sentence, token_size}; + continue; + } + + if (end_split.token_size > chunk_size) { + throw std::runtime_error("split token size is greater than chunk size."); + } + + const int remaining_space = chunk_size - end_split.token_size; + const int overlap_len = std::min({_overlap, remaining_space, start_split.token_size}); + if (overlap_len > 0) { + auto start_tokens = encode(start_split.text); + std::vector overlap_tokens( + start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = decode(overlap_tokens); + end_split = 
SplitUnit{overlap_text + end_split.text, end_split.is_sentence, + end_split.token_size + overlap_len}; + } + + result.insert(result.begin(), end_split.text); + end_split = start_split; + } + + result.insert(result.begin(), end_split.text); + return result; +} + +} From 02cbec4a4cf06612fb0ec089295e0008462210ca Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 10 Feb 2026 19:27:16 +0800 Subject: [PATCH 21/40] Integrate utf8proc to split text to readable chars. --- csrc/CMakeLists.txt | 1 + csrc/cmake/third_party.cmake | 12 +++++ csrc/core/include/text_splitter_base.hpp | 19 +------- csrc/core/src/text_splitter_base.cpp | 58 ++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 17 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 11938c9b7..8756f2758 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(lazyllm_core STATIC ${LAZYLLM_CORE_SOURCES}) target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/include) target_link_libraries(lazyllm_core PUBLIC xxhash) target_link_libraries(lazyllm_core PUBLIC tiktoken) +target_link_libraries(lazyllm_core PUBLIC utf8proc) # Config lazyllm_adaptor lib which maintains callback invocations. file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS diff --git a/csrc/cmake/third_party.cmake b/csrc/cmake/third_party.cmake index 3246a8645..794f2d51e 100644 --- a/csrc/cmake/third_party.cmake +++ b/csrc/cmake/third_party.cmake @@ -26,3 +26,15 @@ if (NOT TARGET cpp_tiktoken) ) FetchContent_MakeAvailable(cpp_tiktoken) endif() + +find_package(utf8proc QUIET) +if (NOT TARGET utf8proc) + # We only need utf8proc for in-tree usage; avoid exporting/installing it. 
+ set(UTF8PROC_INSTALL OFF CACHE BOOL "" FORCE) + FetchContent_Declare( + utf8proc + GIT_REPOSITORY https://github.com/JuliaStrings/utf8proc.git + GIT_TAG v2.9.0 + ) + FetchContent_MakeAvailable(utf8proc) +endif() diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 8da0e3574..f5968300b 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -79,23 +79,8 @@ class TextSplitterBase : public NodeTransform { return static_cast(_tokenizer->encode(view).size()); } - static std::vector regex_find_all(const std::string_view& text, const std::string_view& pattern) { - std::regex re(pattern); - std::vector out; - for (auto it = std::sregex_iterator(text.begin(), text.end(), re); - it != std::sregex_iterator(); ++it) { - out.emplace_back(it->str()); - } - if (out.empty()) out.emplace_back(text); - return out; - } - - static std::vector split_to_chars(const std::string_view& text) { - std::vector out; - out.reserve(text.size()); - for (char c : text) out.emplace_back(1, c); - return out; - } + static std::vector regex_find_all(const std::string_view& text, const std::string_view& pattern); + static std::vector split_to_chars(const std::string_view& text); private: static MapParams _default_params; diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index 877949d08..c17b73835 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -1,5 +1,7 @@ #include "text_splitter_base.hpp" +#include + namespace lazyllm { std::vector TextSplitterBase::split_text(const std::string_view& view, int metadata_size) const { @@ -76,6 +78,62 @@ std::vector TextSplitterBase::split_text_while_keeping_separat return result; } +std::vector TextSplitterBase::regex_find_all( + const std::string_view& text, + const std::string_view& pattern) +{ + std::regex re(pattern); + std::vector out; + for (auto it = 
std::sregex_iterator(text.begin(), text.end(), re); + it != std::sregex_iterator(); ++it) { + out.emplace_back(it->str()); + } + if (out.empty()) out.emplace_back(text); + return out; +} + +std::vector TextSplitterBase::split_to_chars(const std::string_view& text) { + std::vector out; + if (text.empty()) return out; + out.reserve(text.size()); + + const size_t len = text.size(); + size_t cluster_start = 0; + size_t i = 0; + utf8proc_int32_t prev = -1; + utf8proc_propval_t state = 0; + + while (i < len) { + utf8proc_int32_t codepoint = -1; + const utf8proc_ssize_t n = utf8proc_iterate( + reinterpret_cast(text.data() + i), + static_cast(len - i), + &codepoint); + if (n <= 0) { + if (i > cluster_start) { + out.emplace_back(text.substr(cluster_start, i - cluster_start)); + } + out.emplace_back(text.substr(i, 1)); + i += 1; + cluster_start = i; + prev = -1; + state = 0; + continue; + } + + if (prev >= 0 && utf8proc_grapheme_break_stateful(prev, codepoint, &state)) { + out.emplace_back(text.substr(cluster_start, i - cluster_start)); + cluster_start = i; + } + prev = codepoint; + i += static_cast(n); + } + if (cluster_start < len) { + out.emplace_back(text.substr(cluster_start, len - cluster_start)); + } + return out; +} + std::vector TextSplitterBase::_merge(std::vector splits, int chunk_size) { if (splits.empty()) return {}; if (splits.size() == 1) return {splits.front().text}; From af7e617f89ebb360ff9ce19dd767f277136066de Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 12 Feb 2026 16:59:36 +0800 Subject: [PATCH 22/40] UnicodeProcessor --- csrc/core/include/text_splitter_base.hpp | 6 +- csrc/core/include/unicode_processor.hpp | 31 +++++++++ csrc/core/include/utils.hpp | 1 + csrc/core/src/text_splitter_base.cpp | 68 ++----------------- csrc/core/src/unicode_processor.cpp | 85 ++++++++++++++++++++++++ 5 files changed, 124 insertions(+), 67 deletions(-) create mode 100644 csrc/core/include/unicode_processor.hpp create mode 100644 csrc/core/src/unicode_processor.cpp diff 
--git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index f5968300b..65843436f 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -64,7 +63,7 @@ class TextSplitterBase : public NodeTransform { virtual std::vector _merge(std::vector splits, int chunk_size); private: - std::tuple, bool> split_by_functions(const std::string& text) const; + std::tuple, bool> split_by_functions(const std::string_view& text) const; std::vector split_text_while_keeping_separator( const std::string_view& text, const std::string_view& separator) const; @@ -79,9 +78,6 @@ class TextSplitterBase : public NodeTransform { return static_cast(_tokenizer->encode(view).size()); } - static std::vector regex_find_all(const std::string_view& text, const std::string_view& pattern); - static std::vector split_to_chars(const std::string_view& text); - private: static MapParams _default_params; diff --git a/csrc/core/include/unicode_processor.hpp b/csrc/core/include/unicode_processor.hpp new file mode 100644 index 000000000..05ef9c268 --- /dev/null +++ b/csrc/core/include/unicode_processor.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace lazyllm { + +class UnicodeProcessor { +public: + UnicodeProcessor(const std::string_view& text) : _text(text), _text_len(text.size()) {} + std::vector split_to_chars() const; + std::vector split_by_punctuation() const; + +private: + using Utf8Visitor = std::function; + + void for_each_utf8_unit(const Utf8Visitor& visitor) const; + static bool is_sentence_punctuation(char32_t codepoint) { + return std::find(kPunctuationCodepoints.begin(), kPunctuationCodepoints.end(), + codepoint) != kPunctuationCodepoints.end(); + } + + static const std::array kPunctuationCodepoints; + std::string_view _text; + size_t _text_len = 0; +}; + +} // namespace lazyllm diff 
--git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 680955553..919364e48 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace lazyllm { diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index c17b73835..0bb9f6e9b 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -1,6 +1,5 @@ #include "text_splitter_base.hpp" - -#include +#include "unicode_processor.hpp" namespace lazyllm { @@ -44,18 +43,18 @@ std::vector TextSplitterBase::split_recursive( return splits; } -std::tuple, bool> TextSplitterBase::split_by_functions(const std::string& text) const +std::tuple, bool> TextSplitterBase::split_by_functions(const std::string_view& text) const { auto views = split_text_while_keeping_separator(text, "\n\n\n"); if (views.size() > 1) return {views, true}; - views = regex_find_all(text, R"([^,.;。?!]+[,.;。?!]?)"); + views = UnicodeProcessor(text).split_by_punctuation(); if (views.size() > 1) return {views, false}; views = split_text_while_keeping_separator(text, " "); if (views.size() > 1) return {views, false}; - return {split_to_chars(text), false}; + return {UnicodeProcessor(text).split_to_chars(), false}; } std::vector TextSplitterBase::split_text_while_keeping_separator( @@ -71,69 +70,14 @@ std::vector TextSplitterBase::split_text_while_keeping_separat while (start <= text.size()) { const size_t idx = text.find(separator, start); const size_t end = (idx == std::string_view::npos) ? 
text.size() : idx; - if (end > start) result.emplace_back(text.substr(start, end - start)); + if (end > start) // Drop empty strings + result.emplace_back(text.substr(start, end - start)); if (idx == std::string_view::npos) break; start = idx + sep_len; } return result; } -std::vector TextSplitterBase::regex_find_all( - const std::string_view& text, - const std::string_view& pattern) -{ - std::regex re(pattern); - std::vector out; - for (auto it = std::sregex_iterator(text.begin(), text.end(), re); - it != std::sregex_iterator(); ++it) { - out.emplace_back(it->str()); - } - if (out.empty()) out.emplace_back(text); - return out; -} - -std::vector TextSplitterBase::split_to_chars(const std::string_view& text) { - std::vector out; - if (text.empty()) return out; - out.reserve(text.size()); - - const size_t len = text.size(); - size_t cluster_start = 0; - size_t i = 0; - utf8proc_int32_t prev = -1; - utf8proc_propval_t state = 0; - - while (i < len) { - utf8proc_int32_t codepoint = -1; - const utf8proc_ssize_t n = utf8proc_iterate( - reinterpret_cast(text.data() + i), - static_cast(len - i), - &codepoint); - if (n <= 0) { - if (i > cluster_start) { - out.emplace_back(text.substr(cluster_start, i - cluster_start)); - } - out.emplace_back(text.substr(i, 1)); - i += 1; - cluster_start = i; - prev = -1; - state = 0; - continue; - } - - if (prev >= 0 && utf8proc_grapheme_break_stateful(prev, codepoint, &state)) { - out.emplace_back(text.substr(cluster_start, i - cluster_start)); - cluster_start = i; - } - prev = codepoint; - i += static_cast(n); - } - if (cluster_start < len) { - out.emplace_back(text.substr(cluster_start, len - cluster_start)); - } - return out; -} - std::vector TextSplitterBase::_merge(std::vector splits, int chunk_size) { if (splits.empty()) return {}; if (splits.size() == 1) return {splits.front().text}; diff --git a/csrc/core/src/unicode_processor.cpp b/csrc/core/src/unicode_processor.cpp new file mode 100644 index 000000000..26d80f3b2 --- /dev/null 
+++ b/csrc/core/src/unicode_processor.cpp @@ -0,0 +1,85 @@ +#include "unicode_processor.hpp" +namespace lazyllm { + +const std::array UnicodeProcessor::kPunctuationCodepoints = { + U',', + U'.', + U';', + U'!', + U'\uFF0C', // , + U'\uFF1B', // ; + U'\u3002', // 。 + U'\uFF1F', // ? + U'\uFF01', // ! +}; + +void UnicodeProcessor::for_each_utf8_unit(const Utf8Visitor& visitor) const { + size_t i = 0; + while (i < _text_len) { + utf8proc_int32_t codepoint = -1; + const utf8proc_ssize_t n = utf8proc_iterate( + reinterpret_cast(_text.data() + i), + static_cast(_text_len - i), + &codepoint); + + if (n <= 0) { + i += 1; + continue; + } + + visitor(i, static_cast(n), codepoint); + i += static_cast(n); + } +} + +std::vector UnicodeProcessor::split_to_chars() const { + std::vector out; + if (_text.empty()) return out; + out.reserve(_text_len); // Grapheme count <= byte length + + size_t cluster_start = std::string_view::npos; + utf8proc_int32_t prev = -1; + utf8proc_int32_t state = 0; + + for_each_utf8_unit([&](size_t offset, size_t byte_len, utf8proc_int32_t codepoint) { + if (cluster_start == std::string_view::npos) { + cluster_start = offset; + } else if (utf8proc_grapheme_break_stateful(prev, codepoint, &state)) { + out.emplace_back(_text.substr(cluster_start, offset - cluster_start)); + cluster_start = offset; + } + prev = codepoint; + }); + + if (cluster_start != std::string_view::npos) { + out.emplace_back(_text.substr(cluster_start)); + } + return out; +} + +std::vector UnicodeProcessor::split_by_punctuation() const { + if (_text.empty()) return {}; + + std::vector out; + size_t chunk_start = std::string_view::npos; + + for_each_utf8_unit([&](size_t offset, size_t byte_len, char32_t codepoint) { + if (is_sentence_punctuation(codepoint)) { + if (chunk_start != std::string_view::npos) { + const size_t end = offset + byte_len; + out.emplace_back(_text.substr(chunk_start, end - chunk_start)); + chunk_start = std::string_view::npos; + } + } else if (chunk_start == 
std::string_view::npos) { + chunk_start = offset; + } + }); + + if (chunk_start != std::string_view::npos) { + out.emplace_back(_text.substr(chunk_start)); + } + if (out.empty()) out.emplace_back(_text); + return out; +} + +} // namespace lazyllm From 1c7ee82837070870a32a334c1d2e11fe78a87b34 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 13 Feb 2026 15:00:06 +0800 Subject: [PATCH 23/40] text splitter base cpp finish --- csrc/core/include/doc_node.hpp | 1 + csrc/core/include/text_splitter_base.hpp | 15 ++- csrc/core/include/utils.hpp | 25 ++-- csrc/core/src/text_splitter_base.cpp | 144 +++++++++++++++++------ 4 files changed, 135 insertions(+), 50 deletions(-) diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index d91e18b09..825cf0b76 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -19,6 +19,7 @@ namespace lazyllm { enum class MetadataMode { ALL, EMBED, LLM, NONE }; +// TODO: Refactor docnode management from NodeTransform to "Parant" DocNode class DocNode { public: using Metadata = std::unordered_map; diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 65843436f..3be063325 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -47,10 +48,18 @@ class TextSplitterBase : public NodeTransform { std::vector transform(const DocNode* node) const override { if (node == nullptr) return {}; - return split_text(node->get_text_view(), get_node_metadata_size(node)); + auto chunks = split_text(node->get_text_view(), get_node_metadata_size(node)); + std::vector nodes; + nodes.reserve(chunks.size()); + for (const auto& chunk : chunks) { + DocNode chunk_node(chunk); + chunk_node.set_root_text(std::string(chunk)); + nodes.emplace_back(std::move(chunk_node)); + } + return nodes; } - std::vector split_text(const std::string_view& view, int 
metadata_size) const; + std::vector split_text(const std::string_view& view, int metadata_size) const; virtual void set_split_functions( const std::vector&, @@ -60,7 +69,7 @@ class TextSplitterBase : public NodeTransform { protected: virtual std::vector split_recursive(const std::string_view& view, const int chunk_size) const; - virtual std::vector _merge(std::vector splits, int chunk_size); + virtual std::vector merge_chunks(const std::vector& splits, int chunk_size) const; private: std::tuple, bool> split_by_functions(const std::string_view& text) const; diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 919364e48..477efe99b 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -10,16 +10,17 @@ namespace lazyllm { -// RAG system metadata keys -constexpr std::string_view RAG_KEY_KB_ID = "kb_id"; -constexpr std::string_view RAG_KEY_DOC_ID = "docid"; -constexpr std::string_view RAG_KEY_DOC_PATH = "lazyllm_doc_path"; -constexpr std::string_view RAG_KEY_DOC_FILE_NAME = "file_name"; -constexpr std::string_view RAG_KEY_DOC_FILE_TYPE = "file_type"; -constexpr std::string_view RAG_KEY_DOC_FILE_SIZE = "file_size"; -constexpr std::string_view RAG_KEY_DOC_CREATION_DATE = "creation_date"; -constexpr std::string_view RAG_KEY_DOC_LAST_MODIFIED_DATE = "last_modified_date"; -constexpr std::string_view RAG_KEY_DOC_LAST_ACCESSED_DATE = "last_accessed_date"; +struct RAGMetadataKeys { + inline static constexpr std::string_view KB_ID = "kb_id"; + inline static constexpr std::string_view DOC_ID = "docid"; + inline static constexpr std::string_view DOC_PATH = "lazyllm_doc_path"; + inline static constexpr std::string_view DOC_FILE_NAME = "file_name"; + inline static constexpr std::string_view DOC_FILE_TYPE = "file_type"; + inline static constexpr std::string_view DOC_FILE_SIZE = "file_size"; + inline static constexpr std::string_view DOC_CREATION_DATE = "creation_date"; + inline static constexpr std::string_view DOC_LAST_MODIFIED_DATE = 
"last_modified_date"; + inline static constexpr std::string_view DOC_LAST_ACCESSED_DATE = "last_accessed_date"; +}; inline std::string JoinLines(const std::vector& lines, char delim = '\n') { if (lines.empty()) return {}; @@ -79,4 +80,8 @@ inline std::string GenerateUUID() { return out; } +inline bool is_adjacent(const std::string_view& left, const std::string_view& right) { + return left.data() + left.size() == right.data(); +} + } // namespace lazyllm diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index 0bb9f6e9b..19eabcf3c 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -3,7 +3,30 @@ namespace lazyllm { -std::vector TextSplitterBase::split_text(const std::string_view& view, int metadata_size) const { +/* + * split_text + * ---------- + * Purpose: + * 1) Validate chunk budget after accounting for metadata tokens. + * 2) Recursively split the original text view into token-bounded SplitUnit pieces. + * 3) Merge the pieces into final chunk strings with overlap behavior aligned to Python implementation. + * + * Flow: + * 1) Compute effective_chunk_size = chunk_size - metadata_size. + * 2) Reject invalid/too-small budgets. + * 3) Call split_recursive(...) to produce SplitUnit sequence. + * 4) Call merge_chunks(...) to build final std::string chunks. + * + * Notes: + * - This function returns std::string chunks intentionally because current tokenizer + * encode/decode materializes strings in the merge path. + * - Ownership is explicit here to avoid dangling string_view in downstream DocNode construction. + * + * TODO: + * - After tokenizer supports true string_view encode/decode, migrate this path back to + * std::vector and remove eager string materialization. 
+ */ +std::vector TextSplitterBase::split_text(const std::string_view& view, int metadata_size) const { if (view.empty()) return {}; int effective_chunk_size = _chunk_size - metadata_size; if (effective_chunk_size <= 0) { @@ -20,7 +43,7 @@ std::vector TextSplitterBase::split_text(const std::string_view& view, "your metadata to avoid this."); } auto splits = split_recursive(view, effective_chunk_size); - return _merge(splits, effective_chunk_size); + return merge_chunks(splits, effective_chunk_size); } std::vector TextSplitterBase::split_recursive( @@ -67,43 +90,87 @@ std::vector TextSplitterBase::split_text_while_keeping_separat std::vector result; size_t start = 0; const size_t sep_len = separator.size(); - while (start <= text.size()) { + while (start < text.size()) { const size_t idx = text.find(separator, start); - const size_t end = (idx == std::string_view::npos) ? text.size() : idx; - if (end > start) // Drop empty strings - result.emplace_back(text.substr(start, end - start)); - if (idx == std::string_view::npos) break; + if (idx == std::string_view::npos) { + result.emplace_back(text.substr(start)); + break; + } + + if (idx == start) { + start += sep_len; + continue; + } + + result.emplace_back(text.substr(start, idx + sep_len - start)); start = idx + sep_len; } return result; } -std::vector TextSplitterBase::_merge(std::vector splits, int chunk_size) { +/** + * @brief Build final chunks from token-sized split units while preserving overlap semantics. + * + * @details + * 1) Convert input SplitUnit views to owned strings (MergedSplit) for safe concatenation. + * 2) If the tail split exactly matches chunk_size and overlap > 0: + * split it by token-halves via encode/decode, then push both halves back. + * 3) Iterate backward: + * Add previous split, or part of it, to current split as overlap. + * - If the previous split is small enough, prepend it fully. + * - Otherwise, prepend token-based overlap suffix from previous split. 
+ * 4) Emit chunks in original order. + * + * @todo Replace eager string materialization once tokenizer encode/decode supports + * end-to-end zero-copy string_view operations. + */ +std::vector TextSplitterBase::merge_chunks(const std::vector& splits, int chunk_size) const { if (splits.empty()) return {}; - if (splits.size() == 1) return {splits.front().text}; - SplitUnit end_split = splits.back(); - if (end_split.token_size == chunk_size && _overlap > 0) { - splits.pop_back(); - auto text_tokens = encode(end_split.text); + struct MergedSplit { + std::string text; + bool is_sentence = false; + int token_size = 0; + + MergedSplit& operator+=(const MergedSplit& r) { + text += r.text; + is_sentence = is_sentence && r.is_sentence; + token_size += r.token_size; + return *this; + } + }; + std::vector merged_splits; + merged_splits.reserve(splits.size() + 2); + for (const auto& split : splits) + merged_splits.push_back(MergedSplit{std::string(split.view), split.is_sentence, split.token_size}); + + if (merged_splits.size() == 1) return {merged_splits.front().text}; + + if (merged_splits.back().token_size == chunk_size && _overlap > 0) { + MergedSplit end_split = merged_splits.back(); + merged_splits.pop_back(); + + auto text_tokens = _tokenizer->encode(end_split.text); const size_t half = text_tokens.size() / 2; - std::vector p_tokens(text_tokens.begin(), text_tokens.begin() + half); - std::vector n_tokens(text_tokens.begin() + half, text_tokens.end()); - std::string p_text = decode(p_tokens); - std::string n_text = decode(n_tokens); - splits.push_back(SplitUnit{p_text, end_split.is_sentence, get_token_size(p_text)}); - splits.push_back(SplitUnit{n_text, end_split.is_sentence, get_token_size(n_text)}); - end_split = splits.back(); + const auto split_it = text_tokens.begin() + static_cast::difference_type>(half); + std::vector prefix_tokens(text_tokens.begin(), split_it); + std::vector suffix_tokens(split_it, text_tokens.end()); + + std::string prefix_text = 
_tokenizer->decode(prefix_tokens); + std::string suffix_text = _tokenizer->decode(suffix_tokens); + merged_splits.push_back( + MergedSplit{prefix_text, end_split.is_sentence, get_token_size(prefix_text)}); + merged_splits.push_back( + MergedSplit{suffix_text, end_split.is_sentence, get_token_size(suffix_text)}); } - std::vector result; - for (int idx = static_cast(splits.size()) - 2; idx >= 0; --idx) { - SplitUnit start_split = splits[static_cast(idx)]; - if (start_split.token_size <= _overlap && - end_split.token_size <= chunk_size - _overlap) { - const bool is_sentence = start_split.is_sentence && end_split.is_sentence; - const int token_size = start_split.token_size + end_split.token_size; - end_split = SplitUnit{start_split.text + end_split.text, is_sentence, token_size}; + MergedSplit end_split = merged_splits.back(); + std::vector reversed_result; + reversed_result.reserve(merged_splits.size()); + for (auto idx = merged_splits.size() - 2; idx >= 0; --idx) { + const MergedSplit& start_split = merged_splits[idx]; + if (start_split.token_size <= _overlap && end_split.token_size <= chunk_size - _overlap) { + end_split += start_split; continue; } @@ -114,20 +181,23 @@ std::vector TextSplitterBase::_merge(std::vector splits, const int remaining_space = chunk_size - end_split.token_size; const int overlap_len = std::min({_overlap, remaining_space, start_split.token_size}); if (overlap_len > 0) { - auto start_tokens = encode(start_split.text); - std::vector overlap_tokens( - start_tokens.end() - overlap_len, start_tokens.end()); - std::string overlap_text = decode(overlap_tokens); - end_split = SplitUnit{overlap_text + end_split.text, end_split.is_sentence, - end_split.token_size + overlap_len}; + auto start_tokens = _tokenizer->encode(start_split.text); + std::vector overlap_tokens(start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = _tokenizer->decode(overlap_tokens); + + end_split = MergedSplit{ + overlap_text + end_split.text, + 
end_split.is_sentence, + end_split.token_size + overlap_len}; } - result.insert(result.begin(), end_split.text); + reversed_result.emplace_back(end_split.text); end_split = start_split; } - result.insert(result.begin(), end_split.text); - return result; + reversed_result.emplace_back(end_split.text); + std::reverse(reversed_result.begin(), reversed_result.end()); + return reversed_result; } } From 9ef9bd86d527cfd98264bb0759b06208f9108f3a Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 13 Feb 2026 15:02:46 +0800 Subject: [PATCH 24/40] keys --- csrc/adaptor/document_store.hpp | 4 ++-- csrc/binding/export_doc_node.cpp | 4 ++-- csrc/core/include/doc_node.hpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index fff7477b3..c735f23c5 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -55,8 +55,8 @@ class LAZYLLM_HIDDEN DocumentStore : public AdaptorBaseWrapper { DocNode::Children get_node_children(const DocNode* node) const { DocNode::Children out; - auto& kb_id = std::any_cast(node->_p_global_metadata->at(std::string(RAG_KEY_KB_ID))); - auto& doc_id = std::any_cast(node->_p_global_metadata->at(std::string(RAG_KEY_DOC_ID))); + auto& kb_id = std::any_cast(node->_p_global_metadata->at(std::string(RAGMetadataKeys::KB_ID))); + auto& doc_id = std::any_cast(node->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_ID))); auto& group_name = node->get_group_name(); for(auto& [current_group_name, group] : _node_groups_map) { if (group._parent != group_name) continue; diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 24c82aa8d..8370ad428 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -43,9 +43,9 @@ lazyllm::DocNode init( store_adaptor = lazyllm::DocumentStore::from_store(store, node_groups_map); auto kb_id = std::any_cast((*global_metadata).at( - 
std::string(lazyllm::RAG_KEY_KB_ID))); + std::string(lazyllm::RAGMetadataKeys::KB_ID))); auto doc_id = std::any_cast((*global_metadata).at( - std::string(lazyllm::RAG_KEY_DOC_ID))); + std::string(lazyllm::RAGMetadataKeys::DOC_ID))); if (const auto* parent_uid = std::get_if(&*parent)) { p_parent_node = std::any_cast(store_adaptor->call("get_node", diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 825cf0b76..72b220c8c 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -145,10 +145,10 @@ class DocNode { _excluded_llm_metadata_keys = keys; } std::string get_doc_path() const { - return std::any_cast(get_root_node()->_p_global_metadata->at(std::string(RAG_KEY_DOC_PATH))); + return std::any_cast(get_root_node()->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_PATH))); } void set_doc_path(const std::string& path) { - get_root_node()->_p_global_metadata->operator[](std::string(RAG_KEY_DOC_PATH)) = path; + get_root_node()->_p_global_metadata->operator[](std::string(RAGMetadataKeys::DOC_PATH)) = path; } auto py_get_children_uid() const { auto children = py_get_children(); From 068ca986721155e36b7b0223a230444709b3dc00 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 13 Feb 2026 16:18:05 +0800 Subject: [PATCH 25/40] export --- csrc/adaptor/document_store.hpp | 2 +- csrc/binding/export_node_transform.cpp | 72 +++++-- csrc/binding/export_text_splitter_base.cpp | 239 +++++++++++++++------ csrc/core/include/map_params.hpp | 4 +- csrc/core/include/text_splitter_base.hpp | 20 +- csrc/core/src/text_splitter_base.cpp | 12 +- 6 files changed, 258 insertions(+), 91 deletions(-) diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index c735f23c5..c2411a26a 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -57,7 +57,7 @@ class LAZYLLM_HIDDEN DocumentStore : public AdaptorBaseWrapper { DocNode::Children out; auto& kb_id = 
std::any_cast(node->_p_global_metadata->at(std::string(RAGMetadataKeys::KB_ID))); auto& doc_id = std::any_cast(node->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_ID))); - auto& group_name = node->get_group_name(); + auto& group_name = node->_group_name; for(auto& [current_group_name, group] : _node_groups_map) { if (group._parent != group_name) continue; if (!std::any_cast(call("is_group_active", {{"group", current_group_name}}))) continue; diff --git a/csrc/binding/export_node_transform.cpp b/csrc/binding/export_node_transform.cpp index 82d527553..c67f185fd 100644 --- a/csrc/binding/export_node_transform.cpp +++ b/csrc/binding/export_node_transform.cpp @@ -8,43 +8,87 @@ namespace { -std::vector cast_documents(py::object documents) { - std::vector docs; - if (py::isinstance(documents)) { - for (auto item : documents) docs.push_back(py::cast(item)); - } else { - docs.push_back(documents.cast()); +class PyNodeTransform : public lazyllm::NodeTransform { +public: + using lazyllm::NodeTransform::NodeTransform; + + std::vector transform(const lazyllm::DocNode* document) const override { + py::gil_scoped_acquire gil; + py::function overload = py::get_override(static_cast(this), "transform"); + if (!overload) throw std::runtime_error("NodeTransform.transform is not implemented."); + + py::object result = overload(document); + if (!py::isinstance(result)) { + throw std::runtime_error("NodeTransform.transform must return a sequence."); + } + + std::vector out; + for (auto item : result) { + py::object obj = py::reinterpret_borrow(item); + if (obj.is_none()) continue; + + if (py::isinstance(obj)) { + std::string text = obj.cast(); + if (text.empty()) continue; + out.emplace_back(std::move(text)); + } else { + out.emplace_back(obj.cast()); + } + } + return out; } - return docs; -} +}; } // namespace void exportNodeTransform(py::module& m) { - py::class_(m, "NodeTransform") + py::class_(m, "NodeTransform") .def(py::init(), py::arg("num_workers") = 0) 
.def("batch_forward", [](lazyllm::NodeTransform& self, py::object documents, - const std::string& node_group) { - auto docs = cast_documents(documents); + const std::string& node_group, + py::object /*ref_path*/, + py::kwargs /*kwargs*/) { + std::vector docs; + if (py::isinstance(documents)) { + for (auto item : documents) docs.push_back(py::cast(item)); + } else + docs.push_back(documents.cast()); return self.batch_forward(docs, node_group); }, py::arg("documents"), py::arg("node_group"), - py::kw_only(), + py::arg("ref_path") = py::none(), py::return_value_policy::reference ) + .def("transform", + [](const lazyllm::NodeTransform& self, lazyllm::DocNode* document, py::kwargs /*kwargs*/) { + if (document == nullptr) return std::vector{}; + return self.transform(document); + }, + py::arg("document") + ) + .def("__call__", + [](const lazyllm::NodeTransform& self, lazyllm::DocNode* node, py::kwargs /*kwargs*/) { + if (node == nullptr) return std::vector{}; + return self(*node); + }, + py::arg("node") + ) .def( "with_name", [](lazyllm::NodeTransform& self, py::object name, bool copy) -> lazyllm::NodeTransform& { + (void)copy; if (name.is_none()) return self; - self.with_name(name.cast(), copy); + self._name = name.cast(); return self; }, py::arg("name"), py::kw_only(), py::arg("copy") = true, py::return_value_policy::reference - ); + ) + .def_readwrite("_name", &lazyllm::NodeTransform::_name) + .def_property_readonly("_number_workers", &lazyllm::NodeTransform::worker_num); } diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp index 5010e3bbb..ea99b3d11 100644 --- a/csrc/binding/export_text_splitter_base.cpp +++ b/csrc/binding/export_text_splitter_base.cpp @@ -1,108 +1,221 @@ #include "lazyllm.hpp" -#include "adaptor_base_wrapper.hpp" #include "text_splitter_base.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include #include #include namespace { -class PyTokenizer final : public 
lazyllm::Tokenizer, public lazyllm::AdaptorBaseWrapper { +using SplitFn = lazyllm::TextSplitterBase::SplitFn; + +std::any py_to_any(const py::handle& value) { + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + throw std::runtime_error("Unsupported default parameter type."); +} + +py::object any_to_py(const std::any& value) { + if (value.type() == typeid(bool)) return py::bool_(std::any_cast(value)); + if (value.type() == typeid(int)) return py::int_(std::any_cast(value)); + if (value.type() == typeid(double)) return py::float_(std::any_cast(value)); + if (value.type() == typeid(std::string)) return py::str(std::any_cast(value)); + return py::none(); +} + +class PyTokenizer final : public Tokenizer { public: - explicit PyTokenizer(const py::object& obj) : AdaptorBaseWrapper(obj) {} + enum class Mode { + Generic, + HuggingFace + }; + + explicit PyTokenizer(py::object obj, Mode mode = Mode::Generic) + : _obj(std::move(obj)), _mode(mode) {} std::vector encode(const std::string_view& text) const override { - auto result = call("encode", {{"text", std::string(text)}}); - return std::any_cast>(result); + py::gil_scoped_acquire gil; + py::object func = py::getattr(_obj, "encode", py::none()); + if (func.is_none()) throw std::runtime_error("Tokenizer missing method: encode"); + + py::object result; + if (_mode == Mode::HuggingFace) { + result = func(std::string(text), py::arg("add_special_tokens") = false); + } else { + result = func(std::string(text)); + } + return result.cast>(); } std::string decode(const std::vector& token_ids) const override { - auto result = call("decode", {{"token_ids", token_ids}}); - return std::any_cast(result); + py::gil_scoped_acquire gil; + py::object func = py::getattr(_obj, "decode", py::none()); + if (func.is_none()) throw std::runtime_error("Tokenizer missing method: decode"); + + py::object 
result; + if (_mode == Mode::HuggingFace) { + result = func(token_ids, py::arg("skip_special_tokens") = true); + } else { + result = func(token_ids); + } + return result.cast(); } private: - std::any call_impl( - const std::string& func_name, - const py::object& func, - const std::unordered_map& args) const override - { - if (func.is_none()) { - throw std::runtime_error("Tokenizer missing method: " + func_name); - } - if (func_name == "encode") { - const auto& text = std::any_cast(args.at("text")); - py::object result = func(text); - return std::any(result.cast>()); - } - if (func_name == "decode") { - const auto& ids = std::any_cast&>(args.at("token_ids")); - py::object result = func(ids); - return std::any(result.cast()); - } - throw std::runtime_error("Unknown tokenizer method: " + func_name); + py::object _obj; + Mode _mode; +}; + +class PyTextSplitterBase final : public lazyllm::TextSplitterBase { +public: + using lazyllm::TextSplitterBase::TextSplitterBase; + + std::vector transform(const lazyllm::DocNode* node) const override { + PYBIND11_OVERRIDE( + std::vector, + lazyllm::TextSplitterBase, + transform, + node + ); } }; } // namespace void exportTextSpliterBase(py::module& m) { - py::class_(m, "TextSplitterBase") - .def(py::init< - std::optional, - std::optional, - std::optional, - std::optional>(), + py::class_(m, "_TextSplitterBase") + .def(py::init([]( + std::optional chunk_size, + std::optional overlap, + std::optional num_workers, + py::object encoding_name) { + return std::make_unique( + chunk_size, overlap, num_workers, encoding_name.is_none() ? 
"gpt2" : encoding_name.cast()); + }), py::arg("chunk_size") = py::none(), py::arg("overlap") = py::none(), py::arg("num_workers") = py::none(), py::arg("encoding_name") = py::none() ) - .def("split_text", &lazyllm::TextSplitterBase::split_text, - py::arg("text"), py::arg("metadata_size")) - .def("from_tiktoken_encoding", &lazyllm::TextSplitterBase::from_tiktoken_encoding, - py::arg("encoding_name"), py::return_value_policy::reference) + .def("split_text", + [](const lazyllm::TextSplitterBase& self, const std::string& text, int metadata_size) { + if (text.empty()) return std::vector{""}; + return self.split_text(text, metadata_size); + }, + py::arg("text"), py::arg("metadata_size") + ) + .def_static("split_text_keep_separator", + [](const std::string& text, const std::string& separator) { + auto views = lazyllm::TextSplitterBase::split_text_while_keeping_separator(text, separator); + std::vector out; + out.reserve(views.size()); + for (auto view : views) out.emplace_back(view); + return out; + }, + py::arg("text"), py::arg("separator") + ) + .def("from_tiktoken_encoder", + [](lazyllm::TextSplitterBase& self, + const std::string& encoding_name, + py::object model_name, + py::object /*allowed_special*/, + py::object /*disallowed_special*/, + py::kwargs /*kwargs*/) -> lazyllm::TextSplitterBase& { + if (model_name.is_none()) { + return self.from_tiktoken_encoder(encoding_name, std::nullopt); + } + return self.from_tiktoken_encoder(encoding_name, model_name.cast()); + }, + py::arg("encoding_name") = "gpt2", + py::arg("model_name") = py::none(), + py::arg("allowed_special") = py::none(), + py::arg("disallowed_special") = "all", + py::return_value_policy::reference + ) + .def("from_tiktoken_encoding", + [](lazyllm::TextSplitterBase& self, const std::string& encoding_name) -> lazyllm::TextSplitterBase& { + return self.from_tiktoken_encoder(encoding_name, std::nullopt); + }, + py::arg("encoding_name") = "gpt2", + py::return_value_policy::reference + ) .def("from_tokenizer", 
[](lazyllm::TextSplitterBase& self, py::object tokenizer) -> lazyllm::TextSplitterBase& { - auto adaptor = std::make_shared(tokenizer); - self.set_tokenizer(adaptor); + self.set_tokenizer(std::make_shared(std::move(tokenizer), PyTokenizer::Mode::Generic)); + return self; + }, + py::arg("tokenizer"), + py::return_value_policy::reference + ) + .def("from_huggingface_tokenizer", + [](lazyllm::TextSplitterBase& self, py::object tokenizer) -> lazyllm::TextSplitterBase& { + self.set_tokenizer(std::make_shared(std::move(tokenizer), PyTokenizer::Mode::HuggingFace)); return self; }, py::arg("tokenizer"), py::return_value_policy::reference ) + .def("set_split_fns", + [](lazyllm::TextSplitterBase& self, const std::vector& split_fns, py::object sub_split_fns) { + if (sub_split_fns.is_none()) { + self.set_split_functions(split_fns, std::nullopt); + return; + } + self.set_split_functions(split_fns, sub_split_fns.cast>()); + }, + py::arg("split_fns"), py::arg("sub_split_fns") = py::none() + ) + .def("add_split_fn", + [](lazyllm::TextSplitterBase& self, const SplitFn& split_fn, py::object index) { + if (index.is_none()) { + self.add_split_function(split_fn, std::nullopt); + return; + } + self.add_split_function(split_fn, index.cast()); + }, + py::arg("split_fn"), py::arg("index") = py::none() + ) + .def("clear_split_fns", &lazyllm::TextSplitterBase::clear_split_functions) .def_static("set_default", [](py::kwargs kwargs) { - std::unordered_map params; - for (auto item : kwargs) - params[py::cast(item.first)] = py::cast(item.second); - lazyllm::TextSplitterBase::set_default(params); + lazyllm::MapParams::MapType updates; + for (auto item : kwargs) { + const auto key = py::cast(item.first); + updates[key] = py_to_any(item.second); + } + lazyllm::TextSplitterBase::_default_params.set_default(updates); } ) .def_static("get_default", - [](py::object name) { - if (name.is_none()) return py::cast(lazyllm::TextSplitterBase::get_default()); - auto opt = 
lazyllm::TextSplitterBase::get_default(name.cast()); - if (!opt.has_value()) return py::none(); - return py::cast(*opt); + [](py::object param_name) -> py::object { + const auto defaults = lazyllm::TextSplitterBase::_default_params.get_default(); + if (param_name.is_none()) { + py::dict out; + for (const auto& [key, value] : defaults) { + out[py::str(key)] = any_to_py(value); + } + return py::object(std::move(out)); + } + + std::string key = param_name.cast(); + auto it = defaults.find(key); + if (it == defaults.end() && key == "num_workers") it = defaults.find("worker_num"); + else if (it == defaults.end() && key == "worker_num") it = defaults.find("num_workers"); + if (it == defaults.end()) return py::none(); + return any_to_py(it->second); }, py::arg("param_name") = py::none() ) - .def_static("reset_default", &lazyllm::TextSplitterBase::reset_default); - - py::class_(m, "_TokenTextSplitter") - .def(py::init< - std::optional, - std::optional, - std::optional, - std::optional>(), - py::arg("chunk_size") = py::none(), - py::arg("overlap") = py::none(), - py::arg("num_workers") = py::none(), - py::arg("encoding_name") = py::none() - ); - - m.def("split_text_keep_separator", &lazyllm::split_text_keep_separator, - py::arg("text"), py::arg("separator")); + .def_static("reset_default", []() { lazyllm::TextSplitterBase::_default_params.reset_default(); }); } diff --git a/csrc/core/include/map_params.hpp b/csrc/core/include/map_params.hpp index ed0c7bbc6..0659dca93 100644 --- a/csrc/core/include/map_params.hpp +++ b/csrc/core/include/map_params.hpp @@ -21,7 +21,7 @@ class MapParams { { if (value.has_value()) return *value; std::lock_guard guard(_lock); - auto it = _params.find(param_name); + auto it = _params.find(std::string(param_name)); if (it != _params.end()) return std::any_cast(it->second); return default_value; } @@ -29,7 +29,7 @@ class MapParams { template void set_default(const std::string_view& param_name, T value) { std::lock_guard guard(_lock); - 
_params[param_name] = std::any(value); + _params[std::string(param_name)] = std::any(value); } void set_default(const MapType& updates) { diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 3be063325..6b56ee600 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -30,6 +30,7 @@ struct SplitUnit { class TextSplitterBase : public NodeTransform { public: using SplitFn = std::function(const std::string&)>; + static MapParams _default_params; explicit TextSplitterBase( std::optional chunk_size, @@ -60,6 +61,19 @@ class TextSplitterBase : public NodeTransform { } std::vector split_text(const std::string_view& view, int metadata_size) const; + static std::vector split_text_while_keeping_separator( + const std::string_view& text, + const std::string_view& separator); + + TextSplitterBase& from_tiktoken_encoder( + const std::string& encoding_name = "gpt2", + const std::optional& model_name = std::nullopt) + { + _tokenizer = std::make_shared(model_name.value_or(encoding_name)); + return *this; + } + + void set_tokenizer(std::shared_ptr tokenizer) { _tokenizer = std::move(tokenizer); } virtual void set_split_functions( const std::vector&, @@ -73,9 +87,6 @@ class TextSplitterBase : public NodeTransform { private: std::tuple, bool> split_by_functions(const std::string_view& text) const; - std::vector split_text_while_keeping_separator( - const std::string_view& text, - const std::string_view& separator) const; int get_node_metadata_size(const DocNode* node) const { return std::max( @@ -87,9 +98,6 @@ class TextSplitterBase : public NodeTransform { return static_cast(_tokenizer->encode(view).size()); } -private: - static MapParams _default_params; - protected: int _chunk_size; int _overlap; diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index 19eabcf3c..a3b179289 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ 
b/csrc/core/src/text_splitter_base.cpp @@ -3,6 +3,8 @@ namespace lazyllm { +MapParams TextSplitterBase::_default_params{}; + /* * split_text * ---------- @@ -54,12 +56,12 @@ std::vector TextSplitterBase::split_recursive( auto [views, is_sentence] = split_by_functions(view); std::vector splits; - for (const auto& view : views) { - const int seg_token_size = get_token_size(view); + for (const auto& segment_view : views) { + const int seg_token_size = get_token_size(segment_view); if (seg_token_size <= chunk_size) { - splits.emplace_back(view, is_sentence, seg_token_size); + splits.push_back({segment_view, is_sentence, seg_token_size}); } else { - auto new_splits = split_recursive(view, chunk_size); + auto new_splits = split_recursive(segment_view, chunk_size); splits.insert(splits.end(), new_splits.begin(), new_splits.end()); } } @@ -82,7 +84,7 @@ std::tuple, bool> TextSplitterBase::split_by_funct std::vector TextSplitterBase::split_text_while_keeping_separator( const std::string_view& text, - const std::string_view& separator) const + const std::string_view& separator) { if (text.empty()) return {}; else if (separator.empty()) return {text}; From 19e00dd00e39be6d2d39bdfb8a9235de111a97af Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 13 Feb 2026 17:01:25 +0800 Subject: [PATCH 26/40] sentence_splitter --- csrc/binding/export_sentence_splitter.cpp | 53 +++++++++++++ csrc/binding/lazyllm.cpp | 1 + csrc/binding/lazyllm.hpp | 1 + csrc/core/include/sentence_splitter.hpp | 35 +++++++++ csrc/core/include/text_splitter_base.hpp | 10 +-- csrc/core/include/utils.hpp | 18 +++++ csrc/core/src/sentence_splitter.cpp | 93 +++++++++++++++++++++++ csrc/core/src/text_splitter_base.cpp | 36 +++------ 8 files changed, 215 insertions(+), 32 deletions(-) create mode 100644 csrc/binding/export_sentence_splitter.cpp create mode 100644 csrc/core/include/sentence_splitter.hpp create mode 100644 csrc/core/src/sentence_splitter.cpp diff --git a/csrc/binding/export_sentence_splitter.cpp 
b/csrc/binding/export_sentence_splitter.cpp new file mode 100644 index 000000000..d1e68d2d9 --- /dev/null +++ b/csrc/binding/export_sentence_splitter.cpp @@ -0,0 +1,53 @@ +#include "lazyllm.hpp" + +#include "sentence_splitter.hpp" + +#include +#include +#include + +#include +#include + +namespace { + +class PySentenceSplitter final : public lazyllm::SentenceSplitter { +public: + using lazyllm::SentenceSplitter::SentenceSplitter; + + std::vector transform(const lazyllm::DocNode* node) const override { + PYBIND11_OVERRIDE( + std::vector, + lazyllm::SentenceSplitter, + transform, + node + ); + } +}; + +} // namespace + +void exportSentenceSplitter(py::module& m) { + py::class_< + lazyllm::SentenceSplitter, + lazyllm::TextSplitterBase, + PySentenceSplitter + >(m, "SentenceSplitter") + .def(py::init([]( + std::optional chunk_size, + std::optional chunk_overlap, + std::optional num_workers, + py::object encoding_name + ) { + return std::make_unique( + chunk_size, + chunk_overlap, + num_workers, + encoding_name.is_none() ? 
"gpt2" : encoding_name.cast()); + }), + py::arg("chunk_size") = py::none(), + py::arg("chunk_overlap") = py::none(), + py::arg("num_workers") = py::none(), + py::arg("encoding_name") = py::none() + ); +} diff --git a/csrc/binding/lazyllm.cpp b/csrc/binding/lazyllm.cpp index 96217a732..5c8ec0f4a 100644 --- a/csrc/binding/lazyllm.cpp +++ b/csrc/binding/lazyllm.cpp @@ -18,4 +18,5 @@ PYBIND11_MODULE(lazyllm_cpp, m) { exportDocNode(m); exportNodeTransform(m); exportTextSpliterBase(m); + exportSentenceSplitter(m); } diff --git a/csrc/binding/lazyllm.hpp b/csrc/binding/lazyllm.hpp index c9bc4c965..62a7564f7 100644 --- a/csrc/binding/lazyllm.hpp +++ b/csrc/binding/lazyllm.hpp @@ -13,3 +13,4 @@ void exportAddDocStr(pybind11::module& m); void exportDocNode(pybind11::module& m); void exportNodeTransform(pybind11::module& m); void exportTextSpliterBase(pybind11::module& m); +void exportSentenceSplitter(pybind11::module& m); diff --git a/csrc/core/include/sentence_splitter.hpp b/csrc/core/include/sentence_splitter.hpp new file mode 100644 index 000000000..25f639af5 --- /dev/null +++ b/csrc/core/include/sentence_splitter.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include + +#include "text_splitter_base.hpp" + +namespace lazyllm { + +class SentenceSplitter : public TextSplitterBase { +public: + explicit SentenceSplitter( + std::optional chunk_size, + std::optional chunk_overlap, + std::optional worker_num, + const std::string& encoding_name = "gpt2") + : TextSplitterBase(chunk_size, chunk_overlap, worker_num, encoding_name) {} + +protected: + std::vector merge_chunks(const std::vector& splits, int chunk_size) const override; + +private: + void close_chunk( + std::vector& chunks, + std::vector& cur_chunk, + int& cur_chunk_len, + bool& is_chunk_new) const; + + static std::string trim_ascii(std::string_view input); + static std::string join_parts(const std::vector& parts); +}; + +} // namespace lazyllm diff --git a/csrc/core/include/text_splitter_base.hpp 
b/csrc/core/include/text_splitter_base.hpp index 6b56ee600..2988a9d1d 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -21,12 +21,6 @@ namespace lazyllm { -struct SplitUnit { - std::string_view view; - bool is_sentence = false; - int token_size = 0; -}; - class TextSplitterBase : public NodeTransform { public: using SplitFn = std::function(const std::string&)>; @@ -82,8 +76,8 @@ class TextSplitterBase : public NodeTransform { virtual void clear_split_functions() {} protected: - virtual std::vector split_recursive(const std::string_view& view, const int chunk_size) const; - virtual std::vector merge_chunks(const std::vector& splits, int chunk_size) const; + virtual std::vector split_recursive(const std::string_view& view, const int chunk_size) const; + virtual std::vector merge_chunks(const std::vector& splits, int chunk_size) const; private: std::tuple, bool> split_by_functions(const std::string_view& text) const; diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 477efe99b..42dfe8b0c 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -84,4 +84,22 @@ inline bool is_adjacent(const std::string_view& left, const std::string_view& ri return left.data() + left.size() == right.data(); } +struct ChunkView { + std::string_view view; + bool is_sentence = false; + int token_size = 0; +}; + +struct Chunk { + std::string text; + bool is_sentence = false; + int token_size = 0; + + Chunk& operator+=(const Chunk& r) { + text += r.text; + is_sentence = is_sentence && r.is_sentence; + token_size += r.token_size; + return *this; + } +}; } // namespace lazyllm diff --git a/csrc/core/src/sentence_splitter.cpp b/csrc/core/src/sentence_splitter.cpp new file mode 100644 index 000000000..926f2685b --- /dev/null +++ b/csrc/core/src/sentence_splitter.cpp @@ -0,0 +1,93 @@ +#include "sentence_splitter.hpp" + +#include +#include +#include +#include +#include +#include + +namespace 
lazyllm { + +std::string SentenceSplitter::trim_ascii(std::string_view input) { + size_t start = 0; + while (start < input.size() && std::isspace(static_cast(input[start]))) ++start; + + size_t end = input.size(); + while (end > start && std::isspace(static_cast(input[end - 1]))) --end; + + return std::string(input.substr(start, end - start)); +} + +std::string SentenceSplitter::join_parts(const std::vector& parts) { + size_t total_len = 0; + for (const auto& part : parts) total_len += part.text.size(); + + std::string out; + out.reserve(total_len); + for (const auto& part : parts) out += part.text; + return out; +} + +void SentenceSplitter::close_chunk( + std::vector& chunks, + std::vector& cur_chunk, + int& cur_chunk_len, + bool& is_chunk_new) const +{ + chunks.push_back(join_parts(cur_chunk)); + auto last_chunk = std::move(cur_chunk); + cur_chunk.clear(); + cur_chunk_len = 0; + is_chunk_new = true; + + int overlap_len = 0; + for (auto it = last_chunk.rbegin(); it != last_chunk.rend(); ++it) { + if (overlap_len + it->token_size > _overlap) break; + cur_chunk.push_back(*it); + overlap_len += it->token_size; + cur_chunk_len += it->token_size; + } + std::reverse(cur_chunk.begin(), cur_chunk.end()); +} + +std::vector SentenceSplitter::merge_chunks(const std::vector& splits, int chunk_size) const { + std::vector chunks; + std::vector cur_chunk; + int cur_chunk_len = 0; + bool is_chunk_new = true; + + size_t i = 0; + while (i < splits.size()) { + const auto& cur_split = splits[i]; + if (cur_split.token_size > chunk_size) { + throw std::runtime_error("Single token exceeded chunk size"); + } + + if (cur_chunk_len + cur_split.token_size > chunk_size && !is_chunk_new) { + close_chunk(chunks, cur_chunk, cur_chunk_len, is_chunk_new); + continue; + } + + if (cur_split.is_sentence || cur_chunk_len + cur_split.token_size <= chunk_size || is_chunk_new) { + cur_chunk_len += cur_split.token_size; + cur_chunk.push_back({std::string(cur_split.view), cur_split.is_sentence, 
cur_split.token_size}); + ++i; + is_chunk_new = false; + } else { + close_chunk(chunks, cur_chunk, cur_chunk_len, is_chunk_new); + } + } + + if (!is_chunk_new) chunks.push_back(join_parts(cur_chunk)); + + std::vector out; + out.reserve(chunks.size()); + for (const auto& chunk : chunks) { + auto stripped = trim_ascii(chunk); + if (!stripped.empty()) out.push_back(std::move(stripped)); + } + return out; +} + +} // namespace lazyllm diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index a3b179289..ac1e52d1c 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -48,14 +48,14 @@ std::vector TextSplitterBase::split_text(const std::string_view& vi return merge_chunks(splits, effective_chunk_size); } -std::vector TextSplitterBase::split_recursive( +std::vector TextSplitterBase::split_recursive( const std::string_view& view, const int chunk_size) const { int token_size = get_token_size(view); - if (token_size <= chunk_size) return {SplitUnit{view, true, token_size}}; + if (token_size <= chunk_size) return {ChunkView{view, true, token_size}}; auto [views, is_sentence] = split_by_functions(view); - std::vector splits; + std::vector splits; for (const auto& segment_view : views) { const int seg_token_size = get_token_size(segment_view); if (seg_token_size <= chunk_size) { @@ -126,30 +126,18 @@ std::vector TextSplitterBase::split_text_while_keeping_separat * @todo Replace eager string materialization once tokenizer encode/decode supports * end-to-end zero-copy string_view operations. 
*/ -std::vector TextSplitterBase::merge_chunks(const std::vector& splits, int chunk_size) const { +std::vector TextSplitterBase::merge_chunks(const std::vector& splits, int chunk_size) const { if (splits.empty()) return {}; - struct MergedSplit { - std::string text; - bool is_sentence = false; - int token_size = 0; - - MergedSplit& operator+=(const MergedSplit& r) { - text += r.text; - is_sentence = is_sentence && r.is_sentence; - token_size += r.token_size; - return *this; - } - }; - std::vector merged_splits; + std::vector merged_splits; merged_splits.reserve(splits.size() + 2); for (const auto& split : splits) - merged_splits.push_back(MergedSplit{std::string(split.view), split.is_sentence, split.token_size}); + merged_splits.push_back(Chunk{std::string(split.view), split.is_sentence, split.token_size}); if (merged_splits.size() == 1) return {merged_splits.front().text}; if (merged_splits.back().token_size == chunk_size && _overlap > 0) { - MergedSplit end_split = merged_splits.back(); + Chunk end_split = merged_splits.back(); merged_splits.pop_back(); auto text_tokens = _tokenizer->encode(end_split.text); @@ -161,16 +149,16 @@ std::vector TextSplitterBase::merge_chunks(const std::vectordecode(prefix_tokens); std::string suffix_text = _tokenizer->decode(suffix_tokens); merged_splits.push_back( - MergedSplit{prefix_text, end_split.is_sentence, get_token_size(prefix_text)}); + Chunk{prefix_text, end_split.is_sentence, get_token_size(prefix_text)}); merged_splits.push_back( - MergedSplit{suffix_text, end_split.is_sentence, get_token_size(suffix_text)}); + Chunk{suffix_text, end_split.is_sentence, get_token_size(suffix_text)}); } - MergedSplit end_split = merged_splits.back(); + Chunk end_split = merged_splits.back(); std::vector reversed_result; reversed_result.reserve(merged_splits.size()); for (auto idx = merged_splits.size() - 2; idx >= 0; --idx) { - const MergedSplit& start_split = merged_splits[idx]; + const Chunk& start_split = merged_splits[idx]; if 
(start_split.token_size <= _overlap && end_split.token_size <= chunk_size - _overlap) { end_split += start_split; continue; @@ -187,7 +175,7 @@ std::vector TextSplitterBase::merge_chunks(const std::vector overlap_tokens(start_tokens.end() - overlap_len, start_tokens.end()); std::string overlap_text = _tokenizer->decode(overlap_tokens); - end_split = MergedSplit{ + end_split = Chunk{ overlap_text + end_split.text, end_split.is_sentence, end_split.token_size + overlap_len}; From e0c3acca3b750bc061ed5c40c014ada9a3d0d4c7 Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 24 Feb 2026 10:25:54 +0800 Subject: [PATCH 27/40] compile_options --- csrc/CMakeLists.txt | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 8756f2758..e533b90b0 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -4,10 +4,6 @@ project(LazyLLMCPP LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -add_compile_options( - -Werror - -Wshadow -) # Third party libs include(cmake/third_party.cmake) @@ -20,6 +16,7 @@ target_include_directories(lazyllm_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/core/ target_link_libraries(lazyllm_core PUBLIC xxhash) target_link_libraries(lazyllm_core PUBLIC tiktoken) target_link_libraries(lazyllm_core PUBLIC utf8proc) +target_compile_options(lazyllm_core PRIVATE -Werror -Wshadow) # Config lazyllm_adaptor lib which maintains callback invocations. file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS @@ -27,6 +24,7 @@ file(GLOB_RECURSE LAZYLLM_ADAPTOR_SOURCES CONFIGURE_DEPENDS add_library(lazyllm_adaptor STATIC ${LAZYLLM_ADAPTOR_SOURCES}) target_include_directories(lazyllm_adaptor PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/adaptor) target_link_libraries(lazyllm_adaptor PUBLIC pybind11::headers Python3::Python lazyllm_core) +target_compile_options(lazyllm_adaptor PRIVATE -Werror -Wshadow) # Config lazyllm_cpp lib with binding infomations. 
file(GLOB_RECURSE LAZYLLM_BINDING_SOURCES CONFIGURE_DEPENDS @@ -34,6 +32,7 @@ file(GLOB_RECURSE LAZYLLM_BINDING_SOURCES CONFIGURE_DEPENDS set(INTERFACE_TARGET_NAME lazyllm_cpp) pybind11_add_module(${INTERFACE_TARGET_NAME} ${LAZYLLM_BINDING_SOURCES}) target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core lazyllm_adaptor) +target_compile_options(${INTERFACE_TARGET_NAME} PRIVATE -Werror -Wshadow) if (CMAKE_BUILD_TYPE STREQUAL "Debug") # SHOW_SYMBOL From 06aa5861bd29447baf385a1ae2fc7722471e33d8 Mon Sep 17 00:00:00 2001 From: yzh Date: Tue, 24 Feb 2026 14:55:58 +0800 Subject: [PATCH 28/40] tests in cpp side --- csrc/binding/lazyllm.cpp | 3 +- csrc/core/include/adaptor_base.hpp | 5 - csrc/core/include/doc_node.hpp | 6 +- csrc/core/src/text_splitter_base.cpp | 4 +- csrc/tests/test_adaptor_base.cpp | 35 ++++++ csrc/tests/test_doc_node.cpp | 156 +++++++++++++++++++++++-- csrc/tests/test_map_params.cpp | 50 ++++++++ csrc/tests/test_node_transform.cpp | 125 ++++++++++++++++++++ csrc/tests/test_sentence_splitter.cpp | 52 +++++++++ csrc/tests/test_text_splitter_base.cpp | 111 ++++++++++++++++++ csrc/tests/test_thread_pool.cpp | 25 ++++ csrc/tests/test_tokenizer.cpp | 56 +++++++++ csrc/tests/test_unicode_processor.cpp | 41 +++++++ csrc/tests/test_utils.cpp | 55 +++++++++ 14 files changed, 706 insertions(+), 18 deletions(-) create mode 100644 csrc/tests/test_adaptor_base.cpp create mode 100644 csrc/tests/test_map_params.cpp create mode 100644 csrc/tests/test_node_transform.cpp create mode 100644 csrc/tests/test_sentence_splitter.cpp create mode 100644 csrc/tests/test_text_splitter_base.cpp create mode 100644 csrc/tests/test_thread_pool.cpp create mode 100644 csrc/tests/test_tokenizer.cpp create mode 100644 csrc/tests/test_unicode_processor.cpp create mode 100644 csrc/tests/test_utils.cpp diff --git a/csrc/binding/lazyllm.cpp b/csrc/binding/lazyllm.cpp index 5c8ec0f4a..db701edda 100644 --- a/csrc/binding/lazyllm.cpp +++ b/csrc/binding/lazyllm.cpp @@ -11,10 +11,11 @@ 
PYBIND11_MODULE(lazyllm_cpp, m) { m.doc() = "LazyLLM CPP Module."; exportAddDocStr(m); - // prevent document generation + // Prevent document generation py::options options; options.disable_function_signatures(); + // Export classes exportDocNode(m); exportNodeTransform(m); exportTextSpliterBase(m); diff --git a/csrc/core/include/adaptor_base.hpp b/csrc/core/include/adaptor_base.hpp index 6bb729403..602dcb2d3 100644 --- a/csrc/core/include/adaptor_base.hpp +++ b/csrc/core/include/adaptor_base.hpp @@ -12,11 +12,6 @@ namespace lazyllm { -struct Arg { - std::string name; - std::any value; -}; - class AdaptorBase { public: virtual ~AdaptorBase() = default; diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 72b220c8c..3380c0292 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -133,13 +133,15 @@ class DocNode { _children[group_name] = children_group; } std::set get_excluded_embed_metadata_keys() const { - return SetUnion(get_root_node()->get_excluded_embed_metadata_keys(), _excluded_embed_metadata_keys); + if (_p_parent_node == nullptr) return _excluded_embed_metadata_keys; + return SetUnion(_p_parent_node->get_excluded_embed_metadata_keys(), _excluded_embed_metadata_keys); } void set_excluded_embed_metadata_keys(const std::set& keys) { _excluded_embed_metadata_keys = keys; } std::set get_excluded_llm_metadata_keys() const { - return SetUnion(get_root_node()->get_excluded_llm_metadata_keys(), _excluded_llm_metadata_keys); + if (_p_parent_node == nullptr) return _excluded_llm_metadata_keys; + return SetUnion(_p_parent_node->get_excluded_llm_metadata_keys(), _excluded_llm_metadata_keys); } void set_excluded_llm_metadata_keys(const std::set& keys) { _excluded_llm_metadata_keys = keys; diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index ac1e52d1c..00361edcb 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -157,8 +157,8 
@@ std::vector TextSplitterBase::merge_chunks(const std::vector reversed_result; reversed_result.reserve(merged_splits.size()); - for (auto idx = merged_splits.size() - 2; idx >= 0; --idx) { - const Chunk& start_split = merged_splits[idx]; + for (int idx = static_cast(merged_splits.size()) - 2; idx >= 0; --idx) { + const Chunk& start_split = merged_splits[static_cast(idx)]; if (start_split.token_size <= _overlap && end_split.token_size <= chunk_size - _overlap) { end_split += start_split; continue; diff --git a/csrc/tests/test_adaptor_base.cpp b/csrc/tests/test_adaptor_base.cpp new file mode 100644 index 000000000..a72b9b897 --- /dev/null +++ b/csrc/tests/test_adaptor_base.cpp @@ -0,0 +1,35 @@ +#include + +#include +#include +#include + +#include "adaptor_base.hpp" + +namespace { + +class EchoAdaptor final : public lazyllm::AdaptorBase { +public: + mutable int call_count = 0; + + std::any call( + const std::string& func_name, + const std::unordered_map& args) const override + { + ++call_count; + if (func_name == "sum") { + return std::any_cast(args.at("left")) + std::any_cast(args.at("right")); + } + return func_name; + } +}; + +} // namespace + +TEST(AdaptorBase, DerivedCallReceivesArgsAndReturnsAny) { + EchoAdaptor adaptor; + const auto result = adaptor.call("sum", {{"left", 3}, {"right", 4}}); + + EXPECT_EQ(std::any_cast(result), 7); + EXPECT_EQ(adaptor.call_count, 1); +} diff --git a/csrc/tests/test_doc_node.cpp b/csrc/tests/test_doc_node.cpp index ff3dead26..d209ba341 100644 --- a/csrc/tests/test_doc_node.cpp +++ b/csrc/tests/test_doc_node.cpp @@ -1,16 +1,156 @@ #include -#include "doc_node.h" +#include +#include +#include +#include +#include -TEST(DocNode, DefaultEmpty) { - lazyllm::DocNode node; - EXPECT_EQ(node.get_text(), ""); -} +#include "adaptor_base.hpp" +#include "doc_node.hpp" +#include "utils.hpp" + +namespace { + +class CountingAdaptor final : public lazyllm::AdaptorBase { +public: + mutable int call_count = 0; + lazyllm::DocNode::Children 
to_return; + + std::any call( + const std::string& func_name, + const std::unordered_map& args) const override + { + ++call_count; + EXPECT_EQ(func_name, "get_node_children"); + EXPECT_TRUE(args.find("node") != args.end()); + return to_return; + } +}; + +} // namespace -TEST(DocNode, SetGet) { - lazyllm::DocNode node("hello"); +TEST(DocNode, ConstructorAndTextHashUpdate) { + auto global_meta = std::make_shared(); + lazyllm::DocNode node("hello", "group", "fixed-uid", nullptr, {}, global_meta); + + EXPECT_EQ(node.get_uid(), "fixed-uid"); + EXPECT_EQ(node._group_name, "group"); EXPECT_EQ(node.get_text(), "hello"); - node.set_text("world"); + const size_t old_hash = node.get_text_hash(); + node.set_text_view("world"); + EXPECT_NE(node.get_text_hash(), old_hash); EXPECT_EQ(node.get_text(), "world"); } + +TEST(DocNode, MetadataModesAndTextRender) { + auto global_meta = std::make_shared(); + lazyllm::DocNode::Metadata metadata{ + {"alpha", std::string("A")}, + {"beta", std::string("B")}, + }; + lazyllm::DocNode node("body", "", "", nullptr, metadata, global_meta); + + node.set_excluded_embed_metadata_keys({"beta"}); + node.set_excluded_llm_metadata_keys({"alpha"}); + + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::ALL), "alpha:A\nbeta:B"); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::EMBED), "alpha:A"); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::LLM), "beta:B"); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::NONE), ""); + + EXPECT_EQ(node.get_text(lazyllm::MetadataMode::NONE), "body"); + EXPECT_EQ(node.get_text(lazyllm::MetadataMode::EMBED), "alpha:A\n\nbody"); +} + +TEST(DocNode, ParentChildrenAndUidViews) { + auto global_meta = std::make_shared(); + lazyllm::DocNode root("root", "root_group", "root_uid", nullptr, {}, global_meta); + lazyllm::DocNode child("child", "child_group", "child_uid", &root, {}, global_meta); + + root.set_children_group("split", {&child}); + + EXPECT_EQ(child.get_root_node(), &root); + 
EXPECT_EQ(child.get_parent_uid(), "root_uid"); + EXPECT_TRUE(root.is_children_group_exists("split")); + + const auto child_ids = root.py_get_children_uid(); + ASSERT_TRUE(child_ids.find("split") != child_ids.end()); + ASSERT_EQ(child_ids.at("split").size(), 1u); + EXPECT_EQ(child_ids.at("split")[0], "child_uid"); +} + +TEST(DocNode, StoreBackedChildrenAreCached) { + auto global_meta = std::make_shared(); + lazyllm::DocNode parent("p", "", "parent", nullptr, {}, global_meta); + lazyllm::DocNode child("c", "", "child", &parent, {}, global_meta); + + auto adaptor = std::make_shared(); + adaptor->to_return["cached"] = {&child}; + parent.set_store(adaptor); + + const auto first = parent.py_get_children(); + const auto second = parent.py_get_children(); + + EXPECT_EQ(adaptor->call_count, 1); + ASSERT_TRUE(first.find("cached") != first.end()); + ASSERT_EQ(first.at("cached").size(), 1u); + EXPECT_EQ(first.at("cached")[0], &child); + EXPECT_EQ(second.at("cached")[0], &child); +} + +TEST(DocNode, GlobalDocPathAndExclusionInheritance) { + auto global_meta = std::make_shared(); + (*global_meta)[std::string(lazyllm::RAGMetadataKeys::DOC_PATH)] = std::string("/tmp/a.txt"); + + lazyllm::DocNode root("root", "", "root", nullptr, {}, global_meta); + lazyllm::DocNode child("child", "", "child", &root, {}, global_meta); + + root.set_excluded_embed_metadata_keys({"root_embed"}); + child.set_excluded_embed_metadata_keys({"child_embed"}); + root.set_excluded_llm_metadata_keys({"root_llm"}); + child.set_excluded_llm_metadata_keys({"child_llm"}); + + EXPECT_EQ(child.get_doc_path(), "/tmp/a.txt"); + child.set_doc_path("/tmp/b.txt"); + EXPECT_EQ(root.get_doc_path(), "/tmp/b.txt"); + + EXPECT_EQ( + child.get_excluded_embed_metadata_keys(), + (std::set{"child_embed", "root_embed"})); + EXPECT_EQ( + child.get_excluded_llm_metadata_keys(), + (std::set{"child_llm", "root_llm"})); +} + +TEST(DocNode, EmbeddingHelpers) { + auto global_meta = std::make_shared(); + lazyllm::DocNode node("text", "", 
"node", nullptr, {}, global_meta); + + EXPECT_THROW(node.embedding_keys_undone({}), std::runtime_error); + + node.set_embedding_vec("done", {1.0, 2.0}); + const auto missing = node.embedding_keys_undone({"done", "todo"}); + EXPECT_EQ(missing, (std::set{"todo"})); + + node.py_do_embedding({ + {"len_embedding", [](const std::string& input) { + return std::vector{static_cast(input.size())}; + }} + }); + ASSERT_TRUE(node._embedding_vecs.find("len_embedding") != node._embedding_vecs.end()); + ASSERT_EQ(node._embedding_vecs["len_embedding"].size(), 1u); + EXPECT_EQ(node._embedding_vecs["len_embedding"][0], 6.0); +} + +TEST(DocNode, EqualityUsesUid) { + auto global_meta = std::make_shared(); + lazyllm::DocNode lhs("left", "", "same_uid", nullptr, {}, global_meta); + lazyllm::DocNode rhs("right", "", "same_uid", nullptr, {}, global_meta); + lazyllm::DocNode other("other", "", "other_uid", nullptr, {}, global_meta); + + EXPECT_TRUE(lhs == rhs); + EXPECT_FALSE(lhs != rhs); + EXPECT_TRUE(lhs != other); +} diff --git a/csrc/tests/test_map_params.cpp b/csrc/tests/test_map_params.cpp new file mode 100644 index 000000000..c97af2e78 --- /dev/null +++ b/csrc/tests/test_map_params.cpp @@ -0,0 +1,50 @@ +#include + +#include +#include + +#include "map_params.hpp" + +TEST(MapParams, ExplicitValueHasHighestPriority) { + lazyllm::MapParams params; + params.set_default("chunk_size", 1024u); + + const auto value = params.get_param_value("chunk_size", 4096u, 2048u); + EXPECT_EQ(value, 4096u); +} + +TEST(MapParams, ReadsStoredDefaultAndFallback) { + lazyllm::MapParams params; + params.set_default("worker_num", 8u); + + EXPECT_EQ(params.get_param_value("worker_num", std::nullopt, 4u), 8u); + EXPECT_EQ(params.get_param_value("missing", std::nullopt, 4u), 4u); +} + +TEST(MapParams, BulkSetAndTypedGetDefault) { + lazyllm::MapParams params; + lazyllm::MapParams::MapType updates{ + {"encoding", std::string("gpt2")}, + {"overlap", 128u}, + }; + params.set_default(updates); + + const auto encoding = 
params.get_default("encoding"); + const auto overlap = params.get_default("overlap"); + const auto missing = params.get_default("chunk_size"); + + ASSERT_TRUE(encoding.has_value()); + ASSERT_TRUE(overlap.has_value()); + EXPECT_EQ(*encoding, "gpt2"); + EXPECT_EQ(*overlap, 128u); + EXPECT_FALSE(missing.has_value()); +} + +TEST(MapParams, ResetDefaultClearsState) { + lazyllm::MapParams params; + params.set_default("k", 1); + ASSERT_FALSE(params.get_default().empty()); + + params.reset_default(); + EXPECT_TRUE(params.get_default().empty()); +} diff --git a/csrc/tests/test_node_transform.cpp b/csrc/tests/test_node_transform.cpp new file mode 100644 index 000000000..db338b662 --- /dev/null +++ b/csrc/tests/test_node_transform.cpp @@ -0,0 +1,125 @@ +#include + +#include +#include +#include + +#include "doc_node.hpp" +#include "node_transform.hpp" + +namespace { + +class PairTransform final : public lazyllm::NodeTransform { +public: + explicit PairTransform(int worker_num = 0) : lazyllm::NodeTransform(worker_num) {} + + mutable int transform_calls = 0; + + std::vector transform(const lazyllm::DocNode* node) const override { + ++transform_calls; + const auto shared_meta = node->get_root_node()->_p_global_metadata; + return { + lazyllm::DocNode( + std::string(node->get_text_view()) + "_A", + "", + "", + nullptr, + {}, + shared_meta), + lazyllm::DocNode( + std::string(node->get_text_view()) + "_B", + "", + "", + nullptr, + {}, + shared_meta), + }; + } +}; + +class SingleTransform final : public lazyllm::NodeTransform { +public: + explicit SingleTransform(int worker_num) : lazyllm::NodeTransform(worker_num) {} + + std::vector transform(const lazyllm::DocNode* node) const override { + const auto shared_meta = node->get_root_node()->_p_global_metadata; + return { + lazyllm::DocNode( + std::string(node->get_text_view()) + "_child", + "", + "", + nullptr, + {}, + shared_meta) + }; + } +}; + +} // namespace + +TEST(NodeTransform, BatchForwardCreatesChildrenAndSetsParentGroup) { + 
auto shared_meta = std::make_shared(); + std::vector roots; + roots.reserve(2); + roots.emplace_back("left", "", "left_uid", nullptr, lazyllm::DocNode::Metadata(), shared_meta); + roots.emplace_back("right", "", "right_uid", nullptr, lazyllm::DocNode::Metadata(), shared_meta); + + std::vector root_ptrs{&roots[0], &roots[1]}; + PairTransform transform(0); + const auto direct = transform(roots[0]); + EXPECT_EQ(direct.size(), 2u); + + const auto children = transform.batch_forward(root_ptrs, "split"); + ASSERT_EQ(children.size(), 4u); + + for (const auto* child : children) { + EXPECT_NE(child->get_parent_node(), nullptr); + EXPECT_EQ(child->_group_name, "split"); + } + for (const auto& root : roots) { + EXPECT_TRUE(root.is_children_group_exists("split")); + EXPECT_EQ(root.py_get_children().at("split").size(), 2u); + } +} + +TEST(NodeTransform, BatchForwardSkipsExistingGroup) { + auto shared_meta = std::make_shared(); + lazyllm::DocNode root("root", "", "uid", nullptr, lazyllm::DocNode::Metadata(), shared_meta); + root.set_children_group("already", {}); + std::vector root_ptrs{&root}; + + PairTransform transform(0); + const auto children = transform.batch_forward(root_ptrs, "already"); + + EXPECT_TRUE(children.empty()); + EXPECT_EQ(transform.transform_calls, 0); +} + +TEST(NodeTransform, BatchForwardSupportsParallelMode) { + auto shared_meta = std::make_shared(); + std::vector roots; + roots.reserve(8); + for (int i = 0; i < 8; ++i) { + roots.emplace_back( + "n" + std::to_string(i), + "", + "uid_" + std::to_string(i), + nullptr, + lazyllm::DocNode::Metadata(), + shared_meta); + } + + std::vector root_ptrs; + root_ptrs.reserve(roots.size()); + for (auto& root : roots) root_ptrs.push_back(&root); + + SingleTransform transform(3); + const auto children = transform.batch_forward(root_ptrs, "parallel"); + + EXPECT_EQ(transform.worker_num(), 3); + EXPECT_EQ(children.size(), roots.size()); + for (const auto& root : roots) { + 
EXPECT_TRUE(root.is_children_group_exists("parallel")); + EXPECT_EQ(root.py_get_children().at("parallel").size(), 1u); + } +} diff --git a/csrc/tests/test_sentence_splitter.cpp b/csrc/tests/test_sentence_splitter.cpp new file mode 100644 index 000000000..38f6297c2 --- /dev/null +++ b/csrc/tests/test_sentence_splitter.cpp @@ -0,0 +1,52 @@ +#include + +#include +#include + +#include "sentence_splitter.hpp" +#include "utils.hpp" + +namespace { + +class TestSentenceSplitter final : public lazyllm::SentenceSplitter { +public: + TestSentenceSplitter(unsigned chunk_size, unsigned overlap) + : lazyllm::SentenceSplitter(chunk_size, overlap, 0) {} + + using lazyllm::SentenceSplitter::merge_chunks; +}; + +} // namespace + +TEST(SentenceSplitter, MergeChunksAppliesOverlapAndTrim) { + TestSentenceSplitter splitter(5, 2); + + const std::vector splits{ + {"ab", false, 2}, + {"cd", false, 2}, + {"ef ", false, 2}, + {" ", false, 2}, + }; + + const auto merged = splitter.merge_chunks(splits, 5); + EXPECT_EQ(merged, (std::vector{"abcd", "cdef", "ef"})); +} + +TEST(SentenceSplitter, MergeChunksThrowsOnOversizedSingleSplit) { + TestSentenceSplitter splitter(3, 1); + const std::vector splits{ + {"abcd", false, 4}, + }; + + EXPECT_THROW((void)splitter.merge_chunks(splits, 3), std::runtime_error); +} + +TEST(SentenceSplitter, MergeChunksDropsWhitespaceOnlyChunks) { + TestSentenceSplitter splitter(8, 2); + const std::vector splits{ + {" ", false, 3}, + }; + + const auto merged = splitter.merge_chunks(splits, 8); + EXPECT_TRUE(merged.empty()); +} diff --git a/csrc/tests/test_text_splitter_base.cpp b/csrc/tests/test_text_splitter_base.cpp new file mode 100644 index 000000000..117a3a0ca --- /dev/null +++ b/csrc/tests/test_text_splitter_base.cpp @@ -0,0 +1,111 @@ +#include + +#include +#include +#include +#include +#include + +#include "doc_node.hpp" +#include "text_splitter_base.hpp" +#include "utils.hpp" + +namespace { + +class ByteTokenizer final : public Tokenizer { +public: + std::vector 
encode(const std::string_view& view) const override { + std::vector out; + out.reserve(view.size()); + for (unsigned char ch : view) out.push_back(static_cast(ch)); + return out; + } + + std::string decode(const std::vector& token_ids) const override { + std::string out; + out.reserve(token_ids.size()); + for (int token_id : token_ids) out.push_back(static_cast(token_id)); + return out; + } +}; + +class TestTextSplitter final : public lazyllm::TextSplitterBase { +public: + TestTextSplitter(unsigned chunk_size, unsigned overlap) + : lazyllm::TextSplitterBase(chunk_size, overlap, 0) {} + + using lazyllm::TextSplitterBase::merge_chunks; + using lazyllm::TextSplitterBase::split_recursive; +}; + +} // namespace + +TEST(TextSplitterBase, ConstructorValidatesParameters) { + EXPECT_THROW((void)TestTextSplitter(10, 11), std::runtime_error); + EXPECT_THROW((void)TestTextSplitter(0, 0), std::runtime_error); +} + +TEST(TextSplitterBase, SplitTextKeepingSeparator) { + const auto parts = lazyllm::TextSplitterBase::split_text_while_keeping_separator("a--b--", "--"); + ASSERT_EQ(parts.size(), 2u); + EXPECT_EQ(parts[0], "a--"); + EXPECT_EQ(parts[1], "b--"); + + const auto leading_sep = lazyllm::TextSplitterBase::split_text_while_keeping_separator("--x", "--"); + ASSERT_EQ(leading_sep.size(), 1u); + EXPECT_EQ(leading_sep[0], "x"); +} + +TEST(TextSplitterBase, SplitTextChecksMetadataBudget) { + TestTextSplitter splitter(60, 0); + splitter.set_tokenizer(std::make_shared()); + + EXPECT_THROW((void)splitter.split_text("abc", 60), std::runtime_error); + EXPECT_THROW((void)splitter.split_text("abc", 11), std::runtime_error); +} + +TEST(TextSplitterBase, SplitRecursiveFallsBackToCharLevel) { + TestTextSplitter splitter(100, 0); + splitter.set_tokenizer(std::make_shared()); + + const auto chunks = splitter.split_recursive("abc", 2); + ASSERT_EQ(chunks.size(), 3u); + EXPECT_EQ(chunks[0].view, "a"); + EXPECT_EQ(chunks[1].view, "b"); + EXPECT_EQ(chunks[2].view, "c"); + 
EXPECT_FALSE(chunks[0].is_sentence); +} + +TEST(TextSplitterBase, MergeChunksUsesOverlap) { + TestTextSplitter splitter(100, 1); + splitter.set_tokenizer(std::make_shared()); + + const std::vector splits{ + {"ab", true, 2}, + {"cd", true, 2}, + {"ef", true, 2}, + }; + + const auto merged = splitter.merge_chunks(splits, 4); + EXPECT_EQ(merged, (std::vector{"ab", "bcd", "def"})); +} + +TEST(TextSplitterBase, TransformReturnsChunkNodesAndSupportsNull) { + TestTextSplitter splitter(100, 0); + splitter.set_tokenizer(std::make_shared()); + + auto global_meta = std::make_shared(); + lazyllm::DocNode node("hello", "", "", nullptr, {}, global_meta); + + const auto chunks = splitter.transform(&node); + ASSERT_EQ(chunks.size(), 1u); + EXPECT_EQ(chunks[0].get_text(lazyllm::MetadataMode::NONE), "hello"); + + const auto empty = splitter.transform(nullptr); + EXPECT_TRUE(empty.empty()); +} + +TEST(TextSplitterBase, FromTiktokenEncoderThrowsOnInvalidName) { + TestTextSplitter splitter(100, 0); + EXPECT_THROW((void)splitter.from_tiktoken_encoder("definitely_unknown"), std::runtime_error); +} diff --git a/csrc/tests/test_thread_pool.cpp b/csrc/tests/test_thread_pool.cpp new file mode 100644 index 000000000..6f51bbc16 --- /dev/null +++ b/csrc/tests/test_thread_pool.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include + +#include "thread_pool.hpp" + +TEST(ThreadPool, ExecutesTasksAndReturnsValues) { + ThreadPool pool(3); + + auto f1 = pool.enqueue([] { return 1 + 2; }); + auto f2 = pool.enqueue([](int v) { return v * 2; }, 5); + + EXPECT_EQ(f1.get(), 3); + EXPECT_EQ(f2.get(), 10); +} + +TEST(ThreadPool, PropagatesTaskExceptionThroughFuture) { + ThreadPool pool(1); + auto failing = pool.enqueue([]() -> int { + throw std::runtime_error("boom"); + }); + + EXPECT_THROW((void)failing.get(), std::runtime_error); +} diff --git a/csrc/tests/test_tokenizer.cpp b/csrc/tests/test_tokenizer.cpp new file mode 100644 index 000000000..e2a5e0e3e --- /dev/null +++ b/csrc/tests/test_tokenizer.cpp @@ 
-0,0 +1,56 @@ +#include + +#include +#include +#include + +#include "tokenizer.hpp" + +namespace { + +class IdentityTokenizer final : public Tokenizer { +public: + std::vector encode(const std::string_view& view) const override { + std::vector out; + out.reserve(view.size()); + for (unsigned char ch : view) out.push_back(static_cast(ch)); + return out; + } + + std::string decode(const std::vector& token_ids) const override { + std::string out; + out.reserve(token_ids.size()); + for (int id : token_ids) out.push_back(static_cast(id)); + return out; + } +}; + +} // namespace + +TEST(Tokenizer, AbstractInterfaceViaDerivedClass) { + std::unique_ptr tokenizer = std::make_unique(); + const auto ids = tokenizer->encode("abc"); + + EXPECT_EQ(ids, (std::vector{97, 98, 99})); + EXPECT_EQ(tokenizer->decode(ids), "abc"); +} + +TEST(TiktokenTokenizer, RoundTripEncoding) { + TiktokenTokenizer tokenizer("gpt2"); + const std::string text = "hello tokenizer"; + + const auto token_ids = tokenizer.encode(text); + EXPECT_FALSE(token_ids.empty()); + EXPECT_EQ(tokenizer.decode(token_ids), text); +} + +TEST(TiktokenTokenizer, AliasNamesMapToSameEncoding) { + TiktokenTokenizer gpt2("gpt2"); + TiktokenTokenizer r50k("r50k_base"); + + EXPECT_EQ(gpt2.encode("same input"), r50k.encode("same input")); +} + +TEST(TiktokenTokenizer, UnknownEncodingThrows) { + EXPECT_THROW((void)TiktokenTokenizer("unknown_model"), std::runtime_error); +} diff --git a/csrc/tests/test_unicode_processor.cpp b/csrc/tests/test_unicode_processor.cpp new file mode 100644 index 000000000..e101fb681 --- /dev/null +++ b/csrc/tests/test_unicode_processor.cpp @@ -0,0 +1,41 @@ +#include + +#include +#include + +#include "unicode_processor.hpp" + +namespace { + +std::vector ToStrings(const std::vector& views) { + std::vector out; + out.reserve(views.size()); + for (const auto& view : views) out.emplace_back(view); + return out; +} + +} // namespace + +TEST(UnicodeProcessor, SplitToCharsSupportsMultibyte) { + const std::string 
text = "a你🙂"; + const lazyllm::UnicodeProcessor processor(text); + + const auto chars = ToStrings(processor.split_to_chars()); + EXPECT_EQ(chars, (std::vector{"a", "你", "🙂"})); +} + +TEST(UnicodeProcessor, SplitByPunctuationHandlesAsciiAndCjk) { + const std::string text = "Hello,world。你好!"; + const lazyllm::UnicodeProcessor processor(text); + + const auto chunks = ToStrings(processor.split_by_punctuation()); + EXPECT_EQ(chunks, (std::vector{"Hello,", "world。", "你好!"})); +} + +TEST(UnicodeProcessor, SplitByPunctuationFallbackWhenOnlyPunctuation) { + const std::string text = "!!!"; + const lazyllm::UnicodeProcessor processor(text); + + const auto chunks = ToStrings(processor.split_by_punctuation()); + EXPECT_EQ(chunks, (std::vector{"!!!"})); +} diff --git a/csrc/tests/test_utils.cpp b/csrc/tests/test_utils.cpp new file mode 100644 index 000000000..008510ef5 --- /dev/null +++ b/csrc/tests/test_utils.cpp @@ -0,0 +1,55 @@ +#include + +#include +#include +#include +#include + +#include "utils.hpp" + +TEST(Utils, JoinLinesAndConcatVector) { + EXPECT_EQ(lazyllm::JoinLines({}), ""); + EXPECT_EQ(lazyllm::JoinLines({"a", "b", "c"}), "a\nb\nc"); + EXPECT_EQ(lazyllm::JoinLines({"a", "b", "c"}, ','), "a,b,c"); + + const auto merged = lazyllm::ConcatVector(std::vector{1, 2}, std::vector{3, 4}); + EXPECT_EQ(merged, (std::vector{1, 2, 3, 4})); +} + +TEST(Utils, SetUnionAndSetDiff) { + const std::set left{1, 2, 3}; + const std::set right{3, 4}; + + EXPECT_EQ(lazyllm::SetUnion(left, right), (std::set{1, 2, 3, 4})); + EXPECT_EQ(lazyllm::SetDiff(left, right), (std::set{1, 2})); +} + +TEST(Utils, HexUuidAndAdjacency) { + EXPECT_EQ(lazyllm::to_hex(255u), "ff"); + + const std::string uuid = lazyllm::GenerateUUID(); + const std::regex pattern("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"); + EXPECT_TRUE(std::regex_match(uuid, pattern)); + + const std::string text = "abcdef"; + const std::string_view left = std::string_view(text.data(), 3); + const std::string_view 
right_adjacent = std::string_view(text.data() + 3, 3); + const std::string_view right_not_adjacent = std::string_view(text.data() + 4, 2); + EXPECT_TRUE(lazyllm::is_adjacent(left, right_adjacent)); + EXPECT_FALSE(lazyllm::is_adjacent(left, right_not_adjacent)); +} + +TEST(Utils, ChunkOperatorAccumulatesFields) { + lazyllm::Chunk l{"ab", true, 2}; + lazyllm::Chunk r{"cd", false, 3}; + + l += r; + EXPECT_EQ(l.text, "abcd"); + EXPECT_FALSE(l.is_sentence); + EXPECT_EQ(l.token_size, 5); +} + +TEST(Utils, MetadataKeyConstantsExposed) { + EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_PATH, "lazyllm_doc_path"); + EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_ID, "docid"); +} From a214e355bbb8055d972d57863aef33de67969487 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 27 Feb 2026 15:16:57 +0800 Subject: [PATCH 29/40] libstdc++.so.6 --- csrc/cmake/tests.cmake | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/csrc/cmake/tests.cmake b/csrc/cmake/tests.cmake index 88803297a..5f72a0890 100644 --- a/csrc/cmake/tests.cmake +++ b/csrc/cmake/tests.cmake @@ -10,6 +10,24 @@ FetchContent_MakeAvailable(googletest) enable_testing() include(GoogleTest) +# Resolve libstdc++ from the active C++ compiler. This avoids loading an +# older copy from a preloaded environment (for example Conda) at test runtime. +execute_process( + COMMAND ${CMAKE_CXX_COMPILER} -print-file-name=libstdc++.so.6 + OUTPUT_VARIABLE LIBSTDCPP_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +set(TEST_RUNTIME_ENV "") +if (LIBSTDCPP_PATH AND NOT LIBSTDCPP_PATH STREQUAL "libstdc++.so.6") + # A bare "libstdc++.so.6" means the compiler did not return a concrete path. + get_filename_component(LIBSTDCPP_DIR "${LIBSTDCPP_PATH}" DIRECTORY) + if (LIBSTDCPP_DIR) + # Prepend compiler runtime directory so ctest picks the matching ABI first. 
+ set(TEST_RUNTIME_ENV "LD_LIBRARY_PATH=${LIBSTDCPP_DIR}:$ENV{LD_LIBRARY_PATH}") + endif () +endif () + file(GLOB_RECURSE LAZYLLM_TEST_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/tests/*.cpp" ) @@ -27,5 +45,15 @@ foreach (test_src ${LAZYLLM_TEST_SOURCES}) pybind11::headers Python3::Python ) - gtest_discover_tests(${test_name}) + gtest_add_tests( + TARGET ${test_name} + TEST_LIST discovered_tests + ) + + # Attach runtime env per discovered case so each test gets the same loader path. + if (TEST_RUNTIME_ENV AND discovered_tests) + set_tests_properties(${discovered_tests} PROPERTIES + ENVIRONMENT "${TEST_RUNTIME_ENV}" + ) + endif () endforeach () From e865ab6a7d9e313e30b1efa46158aef92d6f355e Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 27 Feb 2026 15:17:44 +0800 Subject: [PATCH 30/40] DocNode manage itself. --- csrc/binding/export_doc_node.cpp | 12 ++--- csrc/binding/export_node_transform.cpp | 33 +++++------- csrc/binding/export_text_splitter_base.cpp | 1 + csrc/core/include/doc_node.hpp | 21 +++++--- csrc/core/include/map_params.hpp | 2 + csrc/core/include/node_transform.hpp | 59 ++-------------------- csrc/core/src/node_transform.cpp | 45 +++++++++++++++++ csrc/scripts/build_test.sh | 2 +- 8 files changed, 87 insertions(+), 88 deletions(-) create mode 100644 csrc/core/src/node_transform.cpp diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 8370ad428..9270ef489 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -87,10 +87,10 @@ lazyllm::DocNode init( std::string DocNodeToString(const lazyllm::DocNode& node) { py::dict d; - const auto children = node.py_get_children(); + const auto children = node.get_children(); for (const auto& [group, nodes] : children) { py::list ids; - for (const auto* n : nodes) { + for (std::shared_ptr n : nodes) { if (n) ids.append(n->get_uid()); } d[py::str(group)] = std::move(ids); @@ -110,7 +110,7 @@ void exportDocNode(py::module& m) { .value("LLM", 
lazyllm::MetadataMode::LLM) .value("NONE", lazyllm::MetadataMode::NONE); - py::class_(m, "DocNode") + py::class_>(m, "DocNode") .def(py::init(&init), py::kw_only(), py::arg("uid") = py::none(), @@ -167,7 +167,7 @@ void exportDocNode(py::module& m) { py::return_value_policy::reference ) .def_property("children", - [](const lazyllm::DocNode& node) { return node.py_get_children(); }, + [](const lazyllm::DocNode& node) { return node.get_children(); }, [](lazyllm::DocNode& node, const lazyllm::DocNode::Children& children) { node.set_children(children); } @@ -221,10 +221,10 @@ void exportDocNode(py::module& m) { ) .def("get_children_str", [](const lazyllm::DocNode& node) { py::dict d; - const auto children = node.py_get_children(); + const auto children = node.get_children(); for (const auto& [group, nodes] : children) { py::list ids; - for (const auto* n : nodes) if (n) ids.append(n->get_uid()); + for (std::shared_ptr n : nodes) if (n) ids.append(n->get_uid()); d[py::str(group)] = std::move(ids); } return py::str(d); diff --git a/csrc/binding/export_node_transform.cpp b/csrc/binding/export_node_transform.cpp index c67f185fd..99ed61eb1 100644 --- a/csrc/binding/export_node_transform.cpp +++ b/csrc/binding/export_node_transform.cpp @@ -12,28 +12,21 @@ class PyNodeTransform : public lazyllm::NodeTransform { public: using lazyllm::NodeTransform::NodeTransform; - std::vector transform(const lazyllm::DocNode* document) const override { + std::vector transform(lazyllm::PDocNode node) const override { py::gil_scoped_acquire gil; py::function overload = py::get_override(static_cast(this), "transform"); if (!overload) throw std::runtime_error("NodeTransform.transform is not implemented."); - py::object result = overload(document); + py::object result = overload(node); if (!py::isinstance(result)) { throw std::runtime_error("NodeTransform.transform must return a sequence."); } - std::vector out; + std::vector out; for (auto item : result) { py::object obj = 
py::reinterpret_borrow(item); if (obj.is_none()) continue; - - if (py::isinstance(obj)) { - std::string text = obj.cast(); - if (text.empty()) continue; - out.emplace_back(std::move(text)); - } else { - out.emplace_back(obj.cast()); - } + out.emplace_back(obj.cast()); } return out; } @@ -50,11 +43,11 @@ void exportNodeTransform(py::module& m) { const std::string& node_group, py::object /*ref_path*/, py::kwargs /*kwargs*/) { - std::vector docs; + std::vector docs; if (py::isinstance(documents)) { - for (auto item : documents) docs.push_back(py::cast(item)); + for (auto item : documents) docs.push_back(py::cast(item)); } else - docs.push_back(documents.cast()); + docs.push_back(documents.cast()); return self.batch_forward(docs, node_group); }, py::arg("documents"), @@ -63,16 +56,16 @@ void exportNodeTransform(py::module& m) { py::return_value_policy::reference ) .def("transform", - [](const lazyllm::NodeTransform& self, lazyllm::DocNode* document, py::kwargs /*kwargs*/) { - if (document == nullptr) return std::vector{}; - return self.transform(document); + [](const lazyllm::NodeTransform& self, lazyllm::PDocNode node, py::kwargs /*kwargs*/) { + if (node == nullptr) return std::vector{}; + return self.transform(node); }, py::arg("document") ) .def("__call__", - [](const lazyllm::NodeTransform& self, lazyllm::DocNode* node, py::kwargs /*kwargs*/) { - if (node == nullptr) return std::vector{}; - return self(*node); + [](const lazyllm::NodeTransform& self, lazyllm::PDocNode node, py::kwargs /*kwargs*/) { + if (node == nullptr) return std::vector{}; + return self(node); }, py::arg("node") ) diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp index ea99b3d11..1b3f05d73 100644 --- a/csrc/binding/export_text_splitter_base.cpp +++ b/csrc/binding/export_text_splitter_base.cpp @@ -198,6 +198,7 @@ void exportTextSpliterBase(py::module& m) { } ) .def_static("get_default", + // TODO: too verbose [](py::object param_name) -> 
py::object { const auto defaults = lazyllm::TextSplitterBase::_default_params.get_default(); if (param_name.is_none()) { diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 3380c0292..7d9941496 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -19,11 +19,12 @@ namespace lazyllm { enum class MetadataMode { ALL, EMBED, LLM, NONE }; -// TODO: Refactor docnode management from NodeTransform to "Parant" DocNode +using PDocNode = PDocNode; + class DocNode { public: using Metadata = std::unordered_map; - using Children = std::unordered_map>; + using Children = std::unordered_map>; using EmbeddingFun = std::function(const std::string&, const std::string&)>; using EmbeddingVecs = std::unordered_map>; @@ -50,7 +51,6 @@ class DocNode { std::shared_ptr _p_store = nullptr; public: - DocNode() = delete; explicit DocNode( const std::string_view& text_view, const std::string& group_name = "", @@ -67,6 +67,11 @@ class DocNode { { set_text_view(text_view); } + DocNode() : DocNode("", "") {} + explicit DocNode(std::string&& text, const std::shared_ptr& global_metadata = {}) + : DocNode(text, "", "", nullptr, {}, global_metadata) { + set_root_text(std::move(text)); + } DocNode(const DocNode&) = default; DocNode& operator=(const DocNode&) = default; @@ -122,14 +127,16 @@ class DocNode { size_t get_text_hash() const { return _text_hash; } const DocNode* get_parent_node() const { return _p_parent_node; } void set_parent_node(const DocNode* p_parent_node) { _p_parent_node = p_parent_node; } - Children py_get_children() const { + Children get_children() const { if (!_children.empty()) return _children; if (_p_store == nullptr) return Children(); _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); return _children; } void set_children(const Children& children) { _children = children; } - void set_children_group(const std::string& group_name, const std::vector& children_group) { + void 
set_children_group( + const std::string& group_name, + const std::vector& children_group) { _children[group_name] = children_group; } std::set get_excluded_embed_metadata_keys() const { @@ -152,8 +159,8 @@ class DocNode { void set_doc_path(const std::string& path) { get_root_node()->_p_global_metadata->operator[](std::string(RAGMetadataKeys::DOC_PATH)) = path; } - auto py_get_children_uid() const { - auto children = py_get_children(); + auto get_children_uid() const { + auto children = get_children(); std::unordered_map> children_uid; for (auto& [group_name, nodes] : children) { children_uid[group_name] = {}; diff --git a/csrc/core/include/map_params.hpp b/csrc/core/include/map_params.hpp index 0659dca93..7b75cebba 100644 --- a/csrc/core/include/map_params.hpp +++ b/csrc/core/include/map_params.hpp @@ -23,6 +23,8 @@ class MapParams { std::lock_guard guard(_lock); auto it = _params.find(std::string(param_name)); if (it != _params.end()) return std::any_cast(it->second); + + // Don't set default value here to avoid parameter pollution. 
return default_value; } diff --git a/csrc/core/include/node_transform.hpp b/csrc/core/include/node_transform.hpp index e4f1f9b9a..c31650374 100644 --- a/csrc/core/include/node_transform.hpp +++ b/csrc/core/include/node_transform.hpp @@ -24,66 +24,17 @@ class NodeTransform { explicit NodeTransform(int worker_num = 0) : _worker_num(worker_num) {} virtual ~NodeTransform() = default; - virtual std::vector transform(const DocNode*) const = 0; - std::vector operator()(const DocNode& node) const { return transform(&node); } - - std::vector batch_forward(std::vector& nodes, const std::string& node_group_name) { - std::vector whole_nodes; - if (nodes.empty()) return whole_nodes; - - if (_worker_num > 0) { - ThreadPool pool(static_cast(_worker_num)); - std::vector>> futures; - futures.reserve(nodes.size()); - for (auto* p_node : nodes) { - futures.emplace_back(pool.enqueue( - [this, p_node, node_group_name] { return forward(p_node, node_group_name); })); - } - - for (auto& fut : futures) { - auto parts = fut.get(); - whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); - } - } else { - for (auto* p_node : nodes) { - auto parts = forward(p_node, node_group_name); - whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); - } - } - - return whole_nodes; - } + virtual std::vector transform(const PDocNode) const = 0; + std::vector operator()(const PDocNode node) const { return transform(node); } + std::vector batch_forward(std::vector& nodes, const std::string& node_group_name); int worker_num() const { return _worker_num; } private: - std::vector forward(DocNode* p_node, const std::string& node_group_name) { - if (p_node->is_children_group_exists(node_group_name)) return {}; - - auto raw_nodes = transform(p_node); - std::vector out; - out.reserve(raw_nodes.size()); - - for (auto& node_ : raw_nodes) { - node_.set_parent_node(p_node); - node_._group_name = node_group_name; - auto child = std::make_unique(std::move(node_)); - auto* ptr = child.get(); - { - 
std::lock_guard lock(_owned_nodes_mutex); - _owned_nodes.emplace_back(std::move(child)); - } - out.push_back(ptr); - } - p_node->set_children_group(node_group_name, out); - - return out; - } + std::vector forward(PDocNode node, const std::string& node_group_name); - std::vector> _owned_nodes; - std::mutex _owned_nodes_mutex; + std::mutex _lock; int _worker_num = 0; - bool _support_rich = false; }; } // namespace lazyllm diff --git a/csrc/core/src/node_transform.cpp b/csrc/core/src/node_transform.cpp new file mode 100644 index 000000000..8d69eea9b --- /dev/null +++ b/csrc/core/src/node_transform.cpp @@ -0,0 +1,45 @@ +#include "node_transform.hpp" + +namespace lazyllm { + +std::vector NodeTransform::batch_forward( + std::vector& nodes, const std::string& node_group_name +) { + std::vector whole_nodes; + if (nodes.empty()) return whole_nodes; + + if (_worker_num > 0) { + ThreadPool pool(static_cast(_worker_num)); + std::vector>> futures; + futures.reserve(nodes.size()); + for (auto node : nodes) { + futures.emplace_back(pool.enqueue( + [this, node, node_group_name]() { return forward(node, node_group_name); })); + } + + for (auto& fut : futures) { + auto parts = fut.get(); + whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); + } + } else { + for (auto node : nodes) { + auto parts = forward(node, node_group_name); + whole_nodes.insert(whole_nodes.end(), parts.begin(), parts.end()); + } + } + + return whole_nodes; +} + +std::vector NodeTransform::forward(PDocNode node, const std::string& node_group_name) { + auto p_nodes = transform(node); + for (auto& p_node : p_nodes) { + p_node->set_parent_node(&*node); + p_node->_group_name = node_group_name; + } + node->set_children_group(node_group_name, p_nodes); + + return p_nodes; +} + +} // namespace lazyllm diff --git a/csrc/scripts/build_test.sh b/csrc/scripts/build_test.sh index cd9336823..c6b14f9c3 100644 --- a/csrc/scripts/build_test.sh +++ b/csrc/scripts/build_test.sh @@ -6,4 +6,4 @@ cmake -S csrc -B 
build \ -DCMAKE_BUILD_TYPE=Debug \ -DBUILD_TESTS=ON cmake --build build -ctest --test-dir build +ctest --test-dir build --rerun-failed --output-on-failure From 2fd858355afaf57d22e7a6104e932e5bcfa909f3 Mon Sep 17 00:00:00 2001 From: yzh Date: Mon, 2 Mar 2026 19:25:53 +0800 Subject: [PATCH 31/40] finish cpp side tests --- csrc/core/include/doc_node.hpp | 4 +- csrc/core/include/sentence_splitter.hpp | 11 +-- csrc/core/include/text_splitter_base.hpp | 29 +++--- csrc/core/src/sentence_splitter.cpp | 108 ++++++++------------- csrc/core/src/text_splitter_base.cpp | 1 + csrc/core/src/unicode_processor.cpp | 19 +++- csrc/tests/test_adaptor_base.cpp | 12 ++- csrc/tests/test_doc_node.cpp | 115 ++++++++++++----------- csrc/tests/test_map_params.cpp | 30 ++++-- csrc/tests/test_node_transform.cpp | 114 +++++++--------------- csrc/tests/test_sentence_splitter.cpp | 24 +++-- csrc/tests/test_smoke.cpp | 2 +- csrc/tests/test_text_splitter_base.cpp | 66 +++++++------ csrc/tests/test_thread_pool.cpp | 11 ++- csrc/tests/test_tokenizer.cpp | 8 +- csrc/tests/test_unicode_processor.cpp | 33 +++---- csrc/tests/test_utils.cpp | 41 ++++++-- 17 files changed, 300 insertions(+), 328 deletions(-) diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 7d9941496..8f7349b0f 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -19,7 +19,7 @@ namespace lazyllm { enum class MetadataMode { ALL, EMBED, LLM, NONE }; -using PDocNode = PDocNode; +using PDocNode = std::shared_ptr; class DocNode { public: @@ -119,7 +119,7 @@ class DocNode { const auto& metadata_string = get_metadata_string(mode); return metadata_string + "\n\n" + std::string(_text_view); } - void set_root_text(const std::string&& text) { + void set_root_text(std::string&& text) { _p_root_text = std::make_shared(std::move(text)); set_text_view(*_p_root_text); } diff --git a/csrc/core/include/sentence_splitter.hpp b/csrc/core/include/sentence_splitter.hpp index 25f639af5..8963cdd9b 
100644 --- a/csrc/core/include/sentence_splitter.hpp +++ b/csrc/core/include/sentence_splitter.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -20,16 +21,6 @@ class SentenceSplitter : public TextSplitterBase { protected: std::vector merge_chunks(const std::vector& splits, int chunk_size) const override; - -private: - void close_chunk( - std::vector& chunks, - std::vector& cur_chunk, - int& cur_chunk_len, - bool& is_chunk_new) const; - - static std::string trim_ascii(std::string_view input); - static std::string join_parts(const std::vector& parts); }; } // namespace lazyllm diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 2988a9d1d..501b63005 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -27,30 +27,26 @@ class TextSplitterBase : public NodeTransform { static MapParams _default_params; explicit TextSplitterBase( - std::optional chunk_size, - std::optional overlap, - std::optional worker_num, + std::optional chunk_size = std::nullopt, + std::optional overlap = std::nullopt, + std::optional worker_num = std::nullopt, const std::string& encoding_name = "gpt2") : NodeTransform(_default_params.get_param_value("worker_num", worker_num, 0)), _chunk_size(_default_params.get_param_value("chunk_size", chunk_size, 1024)), _overlap(_default_params.get_param_value("overlap", overlap, 200)) { - if (_overlap > _chunk_size) throw std::runtime_error("'overlap' should be less than 'chunk_size'."); + if (_overlap >= _chunk_size) throw std::runtime_error("'overlap' should be less than 'chunk_size'."); if (_chunk_size == 0) throw std::runtime_error("'chunk_size' should > 0"); _tokenizer = std::make_shared(encoding_name); } - std::vector transform(const DocNode* node) const override { - if (node == nullptr) return {}; - auto chunks = split_text(node->get_text_view(), get_node_metadata_size(node)); - std::vector nodes; + std::vector transform(PDocNode 
node) const override { + auto chunks = split_text(node->get_text_view(), get_node_metadata_size(*node)); + std::vector nodes; nodes.reserve(chunks.size()); - for (const auto& chunk : chunks) { - DocNode chunk_node(chunk); - chunk_node.set_root_text(std::string(chunk)); - nodes.emplace_back(std::move(chunk_node)); - } + for (auto& chunk : chunks) + nodes.push_back(std::make_shared(std::move(chunk))); return nodes; } @@ -82,13 +78,14 @@ class TextSplitterBase : public NodeTransform { private: std::tuple, bool> split_by_functions(const std::string_view& text) const; - int get_node_metadata_size(const DocNode* node) const { + int get_node_metadata_size(const DocNode& node) const { return std::max( - get_token_size(node->get_metadata_string(MetadataMode::EMBED)), - get_token_size(node->get_metadata_string(MetadataMode::LLM))); + get_token_size(node.get_metadata_string(MetadataMode::EMBED)), + get_token_size(node.get_metadata_string(MetadataMode::LLM))); } int get_token_size(const std::string_view& view) const { + if (view.empty()) return 0; return static_cast(_tokenizer->encode(view).size()); } diff --git a/csrc/core/src/sentence_splitter.cpp b/csrc/core/src/sentence_splitter.cpp index 926f2685b..3b9931903 100644 --- a/csrc/core/src/sentence_splitter.cpp +++ b/csrc/core/src/sentence_splitter.cpp @@ -2,91 +2,57 @@ #include #include -#include -#include -#include -#include namespace lazyllm { -std::string SentenceSplitter::trim_ascii(std::string_view input) { - size_t start = 0; - while (start < input.size() && std::isspace(static_cast(input[start]))) ++start; - - size_t end = input.size(); - while (end > start && std::isspace(static_cast(input[end - 1]))) --end; - - return std::string(input.substr(start, end - start)); -} - -std::string SentenceSplitter::join_parts(const std::vector& parts) { - size_t total_len = 0; - for (const auto& part : parts) total_len += part.text.size(); - +std::string join_views( + size_t string_size, + std::vector::const_iterator begin, + const 
std::vector::const_iterator& end +) { std::string out; - out.reserve(total_len); - for (const auto& part : parts) out += part.text; - return out; -} - -void SentenceSplitter::close_chunk( - std::vector& chunks, - std::vector& cur_chunk, - int& cur_chunk_len, - bool& is_chunk_new) const -{ - chunks.push_back(join_parts(cur_chunk)); - auto last_chunk = std::move(cur_chunk); - cur_chunk.clear(); - cur_chunk_len = 0; - is_chunk_new = true; - - int overlap_len = 0; - for (auto it = last_chunk.rbegin(); it != last_chunk.rend(); ++it) { - if (overlap_len + it->token_size > _overlap) break; - cur_chunk.push_back(*it); - overlap_len += it->token_size; - cur_chunk_len += it->token_size; + out.reserve(string_size); + while(begin != end) { + out.append(begin->view); + ++begin; } - std::reverse(cur_chunk.begin(), cur_chunk.end()); + return out; } -std::vector SentenceSplitter::merge_chunks(const std::vector& splits, int chunk_size) const { - std::vector chunks; - std::vector cur_chunk; - int cur_chunk_len = 0; - bool is_chunk_new = true; +std::vector SentenceSplitter::merge_chunks(const std::vector& chunks, int chunk_size) const { + std::vector out; - size_t i = 0; - while (i < splits.size()) { - const auto& cur_split = splits[i]; - if (cur_split.token_size > chunk_size) { - throw std::runtime_error("Single token exceeded chunk size"); + auto iLeft = chunks.begin(); + auto iRight = chunks.begin(); + const auto& iEnd = chunks.end(); + int window_token_sum = 0; + size_t string_size = 0; + + while (iRight != iEnd) { + if (iRight->token_size > chunk_size) + throw std::runtime_error("Chunk size is too big."); + + // Grow right edge to the largest window under chunk_size. 
+ while (iRight != iEnd && window_token_sum + iRight->token_size <= chunk_size) { + window_token_sum += iRight->token_size; + string_size += iRight->view.size(); + ++iRight; } - if (cur_chunk_len + cur_split.token_size > chunk_size && !is_chunk_new) { - close_chunk(chunks, cur_chunk, cur_chunk_len, is_chunk_new); - continue; - } + // Merge chunks within window. + out.push_back(join_views(string_size, iLeft, iRight)); - if (cur_split.is_sentence || cur_chunk_len + cur_split.token_size <= chunk_size || is_chunk_new) { - cur_chunk_len += cur_split.token_size; - cur_chunk.push_back({std::string(cur_split.view), cur_split.is_sentence, cur_split.token_size}); - ++i; - is_chunk_new = false; - } else { - close_chunk(chunks, cur_chunk, cur_chunk_len, is_chunk_new); + // Shrink left edge to select overlap of next merge. + while (iRight != iEnd && iLeft != iRight && ( + window_token_sum > _overlap || window_token_sum + iRight->token_size > chunk_size + )) { + window_token_sum -= iLeft->token_size; + string_size -= iLeft->view.size(); + ++iLeft; } + // Now window contains only overlap. 
} - if (!is_chunk_new) chunks.push_back(join_parts(cur_chunk)); - - std::vector out; - out.reserve(chunks.size()); - for (const auto& chunk : chunks) { - auto stripped = trim_ascii(chunk); - if (!stripped.empty()) out.push_back(std::move(stripped)); - } return out; } diff --git a/csrc/core/src/text_splitter_base.cpp b/csrc/core/src/text_splitter_base.cpp index 00361edcb..1ae75877d 100644 --- a/csrc/core/src/text_splitter_base.cpp +++ b/csrc/core/src/text_splitter_base.cpp @@ -58,6 +58,7 @@ std::vector TextSplitterBase::split_recursive( std::vector splits; for (const auto& segment_view : views) { const int seg_token_size = get_token_size(segment_view); + if (seg_token_size == 0) continue; if (seg_token_size <= chunk_size) { splits.push_back({segment_view, is_sentence, seg_token_size}); } else { diff --git a/csrc/core/src/unicode_processor.cpp b/csrc/core/src/unicode_processor.cpp index 26d80f3b2..1abdb8c1c 100644 --- a/csrc/core/src/unicode_processor.cpp +++ b/csrc/core/src/unicode_processor.cpp @@ -32,6 +32,23 @@ void UnicodeProcessor::for_each_utf8_unit(const Utf8Visitor& visitor) const { } } +/** + * UTF-8 text processing has three distinct layers: + * 1) Byte: the storage unit in std::string_view; one code point uses 1-4 UTF-8 bytes. + * 2) Code point: a Unicode scalar value (for example U+0061, U+4E2D), decoded by utf8proc_iterate. + * 3) Grapheme cluster: one user-perceived character, which may contain multiple code points + * (for example base + combining mark, or emoji + VS/ZWJ sequences). + * + * This function splits by grapheme cluster, not by byte or code point: + * - for_each_utf8_unit() uses utf8proc_iterate to decode UTF-8 and provide + * code point, byte offset, and byte length. + * - utf8proc_grapheme_break_stateful(prev, codepoint, &state) determines whether + * there is a grapheme boundary between prev and the current code point. + * - When a boundary appears, we emit a string_view slice over byte range + * [cluster_start, offset). 
+ * + * This keeps splitting zero-copy (string_view) while following Unicode grapheme-boundary rules. + */ std::vector UnicodeProcessor::split_to_chars() const { std::vector out; if (_text.empty()) return out; @@ -67,7 +84,7 @@ std::vector UnicodeProcessor::split_by_punctuation() const { if (is_sentence_punctuation(codepoint)) { if (chunk_start != std::string_view::npos) { const size_t end = offset + byte_len; - out.emplace_back(_text.substr(chunk_start, end - chunk_start)); + out.push_back(_text.substr(chunk_start, end - chunk_start)); chunk_start = std::string_view::npos; } } else if (chunk_start == std::string_view::npos) { diff --git a/csrc/tests/test_adaptor_base.cpp b/csrc/tests/test_adaptor_base.cpp index a72b9b897..3a7cc10e8 100644 --- a/csrc/tests/test_adaptor_base.cpp +++ b/csrc/tests/test_adaptor_base.cpp @@ -8,7 +8,7 @@ namespace { -class EchoAdaptor final : public lazyllm::AdaptorBase { +class MockAdaptor final : public lazyllm::AdaptorBase { public: mutable int call_count = 0; @@ -26,10 +26,12 @@ class EchoAdaptor final : public lazyllm::AdaptorBase { } // namespace -TEST(AdaptorBase, DerivedCallReceivesArgsAndReturnsAny) { - EchoAdaptor adaptor; - const auto result = adaptor.call("sum", {{"left", 3}, {"right", 4}}); +TEST(adaptor_base, derived_call) { + MockAdaptor adaptor; + auto result = adaptor.call("echo_me", {}); + EXPECT_EQ(std::any_cast(result), "echo_me"); + EXPECT_EQ(adaptor.call_count, 1); + result = adaptor.call("sum", {{"left", 3}, {"right", 4}}); EXPECT_EQ(std::any_cast(result), 7); - EXPECT_EQ(adaptor.call_count, 1); } diff --git a/csrc/tests/test_doc_node.cpp b/csrc/tests/test_doc_node.cpp index d209ba341..7f9dec7fc 100644 --- a/csrc/tests/test_doc_node.cpp +++ b/csrc/tests/test_doc_node.cpp @@ -12,7 +12,7 @@ namespace { -class CountingAdaptor final : public lazyllm::AdaptorBase { +class MockChildrenAdaptor final : public lazyllm::AdaptorBase { public: mutable int call_count = 0; lazyllm::DocNode::Children to_return; @@ -30,68 +30,84 
@@ class CountingAdaptor final : public lazyllm::AdaptorBase { } // namespace -TEST(DocNode, ConstructorAndTextHashUpdate) { - auto global_meta = std::make_shared(); - lazyllm::DocNode node("hello", "group", "fixed-uid", nullptr, {}, global_meta); - +TEST(doc_node, constructor) { + lazyllm::DocNode node("hello", "group", "fixed-uid"); EXPECT_EQ(node.get_uid(), "fixed-uid"); EXPECT_EQ(node._group_name, "group"); EXPECT_EQ(node.get_text(), "hello"); +} +TEST(doc_node, set_text_view_updates_text_hash) { + lazyllm::DocNode node("hello", "group", "fixed-uid"); const size_t old_hash = node.get_text_hash(); node.set_text_view("world"); EXPECT_NE(node.get_text_hash(), old_hash); EXPECT_EQ(node.get_text(), "world"); } -TEST(DocNode, MetadataModesAndTextRender) { - auto global_meta = std::make_shared(); - lazyllm::DocNode::Metadata metadata{ +TEST(doc_node, metadata) { + lazyllm::DocNode node; + node._metadata = lazyllm::DocNode::Metadata{ {"alpha", std::string("A")}, {"beta", std::string("B")}, }; - lazyllm::DocNode node("body", "", "", nullptr, metadata, global_meta); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::ALL), "alpha:A\nbeta:B"); node.set_excluded_embed_metadata_keys({"beta"}); - node.set_excluded_llm_metadata_keys({"alpha"}); - - EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::ALL), "alpha:A\nbeta:B"); EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::EMBED), "alpha:A"); + + node.set_excluded_llm_metadata_keys({"alpha"}); EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::LLM), "beta:B"); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::NONE), ""); + lazyllm::DocNode root("root"); + lazyllm::DocNode child("child", "", "", &root); + + root.set_excluded_embed_metadata_keys({"root_embed"}); + child.set_excluded_embed_metadata_keys({"child_embed"}); + EXPECT_EQ( + child.get_excluded_embed_metadata_keys(), + (std::set{"child_embed", "root_embed"})); + + root.set_excluded_llm_metadata_keys({"root_llm"}); + 
child.set_excluded_llm_metadata_keys({"child_llm"}); + EXPECT_EQ( + child.get_excluded_llm_metadata_keys(), + (std::set{"child_llm", "root_llm"})); +} + +TEST(doc_node, text) { + lazyllm::DocNode node("body"); + node._metadata = lazyllm::DocNode::Metadata{{"alpha", std::string("A")}}; EXPECT_EQ(node.get_text(lazyllm::MetadataMode::NONE), "body"); EXPECT_EQ(node.get_text(lazyllm::MetadataMode::EMBED), "alpha:A\n\nbody"); } -TEST(DocNode, ParentChildrenAndUidViews) { - auto global_meta = std::make_shared(); - lazyllm::DocNode root("root", "root_group", "root_uid", nullptr, {}, global_meta); - lazyllm::DocNode child("child", "child_group", "child_uid", &root, {}, global_meta); - - root.set_children_group("split", {&child}); - +TEST(doc_node, relationships) { + lazyllm::DocNode root("root", "root_group", "root_uid"); + lazyllm::DocNode child("child", "child_group", "child_uid", &root); EXPECT_EQ(child.get_root_node(), &root); EXPECT_EQ(child.get_parent_uid(), "root_uid"); - EXPECT_TRUE(root.is_children_group_exists("split")); - const auto child_ids = root.py_get_children_uid(); + root.set_children_group("split", {std::make_shared(std::move(child))}); + EXPECT_TRUE(root.is_children_group_exists("split")); + const auto child_ids = root.get_children_uid(); ASSERT_TRUE(child_ids.find("split") != child_ids.end()); ASSERT_EQ(child_ids.at("split").size(), 1u); EXPECT_EQ(child_ids.at("split")[0], "child_uid"); } -TEST(DocNode, StoreBackedChildrenAreCached) { - auto global_meta = std::make_shared(); - lazyllm::DocNode parent("p", "", "parent", nullptr, {}, global_meta); - lazyllm::DocNode child("c", "", "child", &parent, {}, global_meta); +TEST(doc_node, children_caching) { + lazyllm::DocNode parent("p", "", "parent"); + lazyllm::DocNode child("c", "", "child", &parent); - auto adaptor = std::make_shared(); - adaptor->to_return["cached"] = {&child}; + auto adaptor = std::make_shared(); + adaptor->to_return["cached"] = {std::make_shared(std::move(child))}; 
parent.set_store(adaptor); - const auto first = parent.py_get_children(); - const auto second = parent.py_get_children(); + const auto first = parent.get_children(); + const auto second = parent.get_children(); EXPECT_EQ(adaptor->call_count, 1); ASSERT_TRUE(first.find("cached") != first.end()); @@ -100,57 +116,46 @@ TEST(DocNode, StoreBackedChildrenAreCached) { EXPECT_EQ(second.at("cached")[0], &child); } -TEST(DocNode, GlobalDocPathAndExclusionInheritance) { +TEST(doc_node, doc_path_reads_from_root_global_metadata) { auto global_meta = std::make_shared(); (*global_meta)[std::string(lazyllm::RAGMetadataKeys::DOC_PATH)] = std::string("/tmp/a.txt"); - lazyllm::DocNode root("root", "", "root", nullptr, {}, global_meta); lazyllm::DocNode child("child", "", "child", &root, {}, global_meta); - root.set_excluded_embed_metadata_keys({"root_embed"}); - child.set_excluded_embed_metadata_keys({"child_embed"}); - root.set_excluded_llm_metadata_keys({"root_llm"}); - child.set_excluded_llm_metadata_keys({"child_llm"}); - EXPECT_EQ(child.get_doc_path(), "/tmp/a.txt"); + child.set_doc_path("/tmp/b.txt"); EXPECT_EQ(root.get_doc_path(), "/tmp/b.txt"); - - EXPECT_EQ( - child.get_excluded_embed_metadata_keys(), - (std::set{"child_embed", "root_embed"})); - EXPECT_EQ( - child.get_excluded_llm_metadata_keys(), - (std::set{"child_llm", "root_llm"})); } -TEST(DocNode, EmbeddingHelpers) { - auto global_meta = std::make_shared(); - lazyllm::DocNode node("text", "", "node", nullptr, {}, global_meta); - +TEST(doc_node, embedding_keys_undone_throws_on_empty_input) { + lazyllm::DocNode node; EXPECT_THROW(node.embedding_keys_undone({}), std::runtime_error); node.set_embedding_vec("done", {1.0, 2.0}); const auto missing = node.embedding_keys_undone({"done", "todo"}); EXPECT_EQ(missing, (std::set{"todo"})); +} + +TEST(doc_node, py_do_embedding_writes_embedding_vector) { + lazyllm::DocNode node("text"); node.py_do_embedding({ {"len_embedding", [](const std::string& input) { return 
std::vector{static_cast(input.size())}; }} }); + ASSERT_TRUE(node._embedding_vecs.find("len_embedding") != node._embedding_vecs.end()); ASSERT_EQ(node._embedding_vecs["len_embedding"].size(), 1u); - EXPECT_EQ(node._embedding_vecs["len_embedding"][0], 6.0); + EXPECT_EQ(node._embedding_vecs["len_embedding"][0], 6.0); // "\n\ntext" } -TEST(DocNode, EqualityUsesUid) { - auto global_meta = std::make_shared(); - lazyllm::DocNode lhs("left", "", "same_uid", nullptr, {}, global_meta); - lazyllm::DocNode rhs("right", "", "same_uid", nullptr, {}, global_meta); - lazyllm::DocNode other("other", "", "other_uid", nullptr, {}, global_meta); - +TEST(doc_node, equality_uses_uid) { + lazyllm::DocNode lhs("left", "", "same_uid"); + lazyllm::DocNode rhs("right", "", "same_uid"); EXPECT_TRUE(lhs == rhs); - EXPECT_FALSE(lhs != rhs); + + lazyllm::DocNode other("other", "", "other_uid"); EXPECT_TRUE(lhs != other); } diff --git a/csrc/tests/test_map_params.cpp b/csrc/tests/test_map_params.cpp index c97af2e78..10ecc5bae 100644 --- a/csrc/tests/test_map_params.cpp +++ b/csrc/tests/test_map_params.cpp @@ -1,46 +1,56 @@ #include +#include #include #include #include "map_params.hpp" -TEST(MapParams, ExplicitValueHasHighestPriority) { +TEST(map_params, get_param_value) { lazyllm::MapParams params; params.set_default("chunk_size", 1024u); - const auto value = params.get_param_value("chunk_size", 4096u, 2048u); + auto value = params.get_param_value("chunk_size", 4096u, 2048u); EXPECT_EQ(value, 4096u); -} -TEST(MapParams, ReadsStoredDefaultAndFallback) { - lazyllm::MapParams params; - params.set_default("worker_num", 8u); + value = params.get_param_value("chunk_size", std::nullopt, 4u); + EXPECT_EQ(value, 1024u); - EXPECT_EQ(params.get_param_value("worker_num", std::nullopt, 4u), 8u); EXPECT_EQ(params.get_param_value("missing", std::nullopt, 4u), 4u); } -TEST(MapParams, BulkSetAndTypedGetDefault) { +TEST(map_params, bulk_set_updates_defaults) { lazyllm::MapParams params; 
lazyllm::MapParams::MapType updates{ {"encoding", std::string("gpt2")}, {"overlap", 128u}, }; params.set_default(updates); + EXPECT_EQ(std::any_cast(params.get_default().at("encoding")), "gpt2"); + EXPECT_EQ(std::any_cast(params.get_default().at("overlap")), 128u); +} + +TEST(map_params, typed_get_default_returns_value) { + lazyllm::MapParams params; + params.set_default("encoding", std::string("gpt2")); + params.set_default("overlap", 128u); const auto encoding = params.get_default("encoding"); const auto overlap = params.get_default("overlap"); - const auto missing = params.get_default("chunk_size"); ASSERT_TRUE(encoding.has_value()); ASSERT_TRUE(overlap.has_value()); EXPECT_EQ(*encoding, "gpt2"); EXPECT_EQ(*overlap, 128u); +} + +TEST(map_params, typed_get_default_returns_nullopt_when_missing) { + lazyllm::MapParams params; + const auto missing = params.get_default("chunk_size"); EXPECT_FALSE(missing.has_value()); } -TEST(MapParams, ResetDefaultClearsState) { +TEST(map_params, reset_default_clears_state) { lazyllm::MapParams params; params.set_default("k", 1); ASSERT_FALSE(params.get_default().empty()); diff --git a/csrc/tests/test_node_transform.cpp b/csrc/tests/test_node_transform.cpp index db338b662..27e61d76c 100644 --- a/csrc/tests/test_node_transform.cpp +++ b/csrc/tests/test_node_transform.cpp @@ -1,125 +1,83 @@ #include -#include -#include -#include - #include "doc_node.hpp" #include "node_transform.hpp" +using namespace lazyllm; + namespace { -class PairTransform final : public lazyllm::NodeTransform { +class PairTransform final : public NodeTransform { public: - explicit PairTransform(int worker_num = 0) : lazyllm::NodeTransform(worker_num) {} + explicit PairTransform(int worker_num = 0) : NodeTransform(worker_num) {} mutable int transform_calls = 0; - std::vector transform(const lazyllm::DocNode* node) const override { + std::vector transform(PDocNode node) const override { ++transform_calls; - const auto shared_meta = 
node->get_root_node()->_p_global_metadata; + const auto& shared_meta = node->get_root_node()->_p_global_metadata; return { - lazyllm::DocNode( - std::string(node->get_text_view()) + "_A", - "", - "", - nullptr, - {}, - shared_meta), - lazyllm::DocNode( - std::string(node->get_text_view()) + "_B", - "", - "", - nullptr, - {}, - shared_meta), + std::make_shared(std::string(node->get_text_view()) + "_A", shared_meta), + std::make_shared(std::string(node->get_text_view()) + "_B", shared_meta) }; } }; -class SingleTransform final : public lazyllm::NodeTransform { +class SingleTransform final : public NodeTransform { public: - explicit SingleTransform(int worker_num) : lazyllm::NodeTransform(worker_num) {} + explicit SingleTransform(int worker_num) : NodeTransform(worker_num) {} - std::vector transform(const lazyllm::DocNode* node) const override { + std::vector transform(PDocNode node) const override { const auto shared_meta = node->get_root_node()->_p_global_metadata; return { - lazyllm::DocNode( - std::string(node->get_text_view()) + "_child", - "", - "", - nullptr, - {}, - shared_meta) + std::make_shared(std::string(node->get_text_view()) + "_child", shared_meta) }; } }; -} // namespace +std::vector make_roots(size_t count) { + std::vector roots; + roots.reserve(count); + for (size_t i = 0; i < count; ++i) + roots.push_back(std::make_shared("n" + std::to_string(i), "", "uid_" + std::to_string(i))); + return roots; +} -TEST(NodeTransform, BatchForwardCreatesChildrenAndSetsParentGroup) { - auto shared_meta = std::make_shared(); - std::vector roots; - roots.reserve(2); - roots.emplace_back("left", "", "left_uid", nullptr, lazyllm::DocNode::Metadata(), shared_meta); - roots.emplace_back("right", "", "right_uid", nullptr, lazyllm::DocNode::Metadata(), shared_meta); +} // namespace - std::vector root_ptrs{&roots[0], &roots[1]}; - PairTransform transform(0); +TEST(node_transform, call_operator_returns_transform_result) { + auto roots = make_roots(1); + PairTransform 
transform; const auto direct = transform(roots[0]); EXPECT_EQ(direct.size(), 2u); +} - const auto children = transform.batch_forward(root_ptrs, "split"); - ASSERT_EQ(children.size(), 4u); +TEST(node_transform, batch_forward) { + auto roots = make_roots(2); + PairTransform transform; + const auto children = transform.batch_forward(roots, "split"); + EXPECT_EQ(children.size(), 4u); for (const auto* child : children) { EXPECT_NE(child->get_parent_node(), nullptr); EXPECT_EQ(child->_group_name, "split"); } + for (const auto& root : roots) { EXPECT_TRUE(root.is_children_group_exists("split")); - EXPECT_EQ(root.py_get_children().at("split").size(), 2u); + EXPECT_EQ(root.get_children().at("split").size(), 2u); } } -TEST(NodeTransform, BatchForwardSkipsExistingGroup) { - auto shared_meta = std::make_shared(); - lazyllm::DocNode root("root", "", "uid", nullptr, lazyllm::DocNode::Metadata(), shared_meta); - root.set_children_group("already", {}); - std::vector root_ptrs{&root}; - - PairTransform transform(0); - const auto children = transform.batch_forward(root_ptrs, "already"); - - EXPECT_TRUE(children.empty()); - EXPECT_EQ(transform.transform_calls, 0); -} - -TEST(NodeTransform, BatchForwardSupportsParallelMode) { - auto shared_meta = std::make_shared(); - std::vector roots; - roots.reserve(8); - for (int i = 0; i < 8; ++i) { - roots.emplace_back( - "n" + std::to_string(i), - "", - "uid_" + std::to_string(i), - nullptr, - lazyllm::DocNode::Metadata(), - shared_meta); - } - - std::vector root_ptrs; - root_ptrs.reserve(roots.size()); - for (auto& root : roots) root_ptrs.push_back(&root); - +TEST(node_transform, batch_forward_parallel_mode_respects_worker_num) { + auto roots = make_roots(8); SingleTransform transform(3); - const auto children = transform.batch_forward(root_ptrs, "parallel"); - EXPECT_EQ(transform.worker_num(), 3); + + const auto children = transform.batch_forward(roots, "parallel"); EXPECT_EQ(children.size(), roots.size()); for (const auto& root : roots) { 
EXPECT_TRUE(root.is_children_group_exists("parallel")); - EXPECT_EQ(root.py_get_children().at("parallel").size(), 1u); + EXPECT_EQ(root.get_children().at("parallel").size(), 1u); } } diff --git a/csrc/tests/test_sentence_splitter.cpp b/csrc/tests/test_sentence_splitter.cpp index 38f6297c2..150002006 100644 --- a/csrc/tests/test_sentence_splitter.cpp +++ b/csrc/tests/test_sentence_splitter.cpp @@ -18,21 +18,20 @@ class TestSentenceSplitter final : public lazyllm::SentenceSplitter { } // namespace -TEST(SentenceSplitter, MergeChunksAppliesOverlapAndTrim) { +TEST(sentence_splitter, merge_chunks_applies_overlap) { TestSentenceSplitter splitter(5, 2); const std::vector splits{ {"ab", false, 2}, {"cd", false, 2}, - {"ef ", false, 2}, - {" ", false, 2}, + {"ef", false, 2}, }; const auto merged = splitter.merge_chunks(splits, 5); - EXPECT_EQ(merged, (std::vector{"abcd", "cdef", "ef"})); + EXPECT_EQ(merged, (std::vector{"abcd", "cdef"})); } -TEST(SentenceSplitter, MergeChunksThrowsOnOversizedSingleSplit) { +TEST(sentence_splitter, merge_chunks_throws_on_oversized_single_split) { TestSentenceSplitter splitter(3, 1); const std::vector splits{ {"abcd", false, 4}, @@ -41,12 +40,17 @@ TEST(SentenceSplitter, MergeChunksThrowsOnOversizedSingleSplit) { EXPECT_THROW((void)splitter.merge_chunks(splits, 3), std::runtime_error); } -TEST(SentenceSplitter, MergeChunksDropsWhitespaceOnlyChunks) { - TestSentenceSplitter splitter(8, 2); +TEST(sentence_splitter, merge_chunks_shrinks_overlap_to_fit_next_chunk) { + TestSentenceSplitter splitter(5, 4); + const std::vector splits{ - {" ", false, 3}, + {"aa", false, 2}, + {"b", false, 1}, + {"cccc", false, 4}, + {"dd", false, 2}, + {"ee", false, 2}, }; - const auto merged = splitter.merge_chunks(splits, 8); - EXPECT_TRUE(merged.empty()); + const auto merged = splitter.merge_chunks(splits, 5); + EXPECT_EQ(merged, (std::vector{"aab", "bcccc", "ddee"})); } diff --git a/csrc/tests/test_smoke.cpp b/csrc/tests/test_smoke.cpp index 4c291b34b..998974c44 
100644 --- a/csrc/tests/test_smoke.cpp +++ b/csrc/tests/test_smoke.cpp @@ -2,6 +2,6 @@ #include "lazyllm.hpp" -TEST(LazyLLM, Smoke) { +TEST(lazyllm, smoke) { EXPECT_GT(PYBIND11_VERSION_MAJOR, 0); } diff --git a/csrc/tests/test_text_splitter_base.cpp b/csrc/tests/test_text_splitter_base.cpp index 117a3a0ca..692c45482 100644 --- a/csrc/tests/test_text_splitter_base.cpp +++ b/csrc/tests/test_text_splitter_base.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include @@ -17,21 +16,25 @@ class ByteTokenizer final : public Tokenizer { std::vector encode(const std::string_view& view) const override { std::vector out; out.reserve(view.size()); - for (unsigned char ch : view) out.push_back(static_cast(ch)); + for (unsigned char ch : view) { + out.push_back(static_cast(ch)); + } return out; } std::string decode(const std::vector& token_ids) const override { std::string out; out.reserve(token_ids.size()); - for (int token_id : token_ids) out.push_back(static_cast(token_id)); + for (int token_id : token_ids) { + out.push_back(static_cast(token_id)); + } return out; } }; class TestTextSplitter final : public lazyllm::TextSplitterBase { public: - TestTextSplitter(unsigned chunk_size, unsigned overlap) + TestTextSplitter(unsigned chunk_size, unsigned overlap = 0) : lazyllm::TextSplitterBase(chunk_size, overlap, 0) {} using lazyllm::TextSplitterBase::merge_chunks; @@ -40,43 +43,49 @@ class TestTextSplitter final : public lazyllm::TextSplitterBase { } // namespace -TEST(TextSplitterBase, ConstructorValidatesParameters) { - EXPECT_THROW((void)TestTextSplitter(10, 11), std::runtime_error); - EXPECT_THROW((void)TestTextSplitter(0, 0), std::runtime_error); +TEST(text_splitter_base, exception_management) { + EXPECT_THROW((void)lazyllm::TextSplitterBase(10, 11), std::runtime_error); + EXPECT_THROW((void)lazyllm::TextSplitterBase(0, 0), std::runtime_error); } -TEST(TextSplitterBase, SplitTextKeepingSeparator) { +TEST(text_splitter_base, 
split_text_keep_separator_returns_segments) { const auto parts = lazyllm::TextSplitterBase::split_text_while_keeping_separator("a--b--", "--"); ASSERT_EQ(parts.size(), 2u); EXPECT_EQ(parts[0], "a--"); EXPECT_EQ(parts[1], "b--"); +} - const auto leading_sep = lazyllm::TextSplitterBase::split_text_while_keeping_separator("--x", "--"); - ASSERT_EQ(leading_sep.size(), 1u); - EXPECT_EQ(leading_sep[0], "x"); +TEST(text_splitter_base, split_text_keep_separator_skips_leading_separator) { + const auto parts = lazyllm::TextSplitterBase::split_text_while_keeping_separator("--x", "--"); + ASSERT_EQ(parts.size(), 1u); + EXPECT_EQ(parts[0], "x"); } -TEST(TextSplitterBase, SplitTextChecksMetadataBudget) { - TestTextSplitter splitter(60, 0); +TEST(text_splitter_base, split_text_throws_when_metadata_exceeds_chunk_size) { + lazyllm::TextSplitterBase splitter(60, 0); splitter.set_tokenizer(std::make_shared()); - EXPECT_THROW((void)splitter.split_text("abc", 60), std::runtime_error); +} + +TEST(text_splitter_base, split_text_throws_when_metadata_budget_too_small) { + lazyllm::TextSplitterBase splitter(60, 0); + splitter.set_tokenizer(std::make_shared()); EXPECT_THROW((void)splitter.split_text("abc", 11), std::runtime_error); } -TEST(TextSplitterBase, SplitRecursiveFallsBackToCharLevel) { +TEST(text_splitter_base, split_recursive_falls_back_to_char_level) { TestTextSplitter splitter(100, 0); splitter.set_tokenizer(std::make_shared()); const auto chunks = splitter.split_recursive("abc", 2); ASSERT_EQ(chunks.size(), 3u); - EXPECT_EQ(chunks[0].view, "a"); - EXPECT_EQ(chunks[1].view, "b"); - EXPECT_EQ(chunks[2].view, "c"); + EXPECT_EQ(chunks[0], "a"); + EXPECT_EQ(chunks[1], "b"); + EXPECT_EQ(chunks[2], "c"); EXPECT_FALSE(chunks[0].is_sentence); } -TEST(TextSplitterBase, MergeChunksUsesOverlap) { +TEST(text_splitter_base, merge_chunks_uses_overlap) { TestTextSplitter splitter(100, 1); splitter.set_tokenizer(std::make_shared()); @@ -90,22 +99,21 @@ TEST(TextSplitterBase, 
MergeChunksUsesOverlap) { EXPECT_EQ(merged, (std::vector{"ab", "bcd", "def"})); } -TEST(TextSplitterBase, TransformReturnsChunkNodesAndSupportsNull) { - TestTextSplitter splitter(100, 0); +TEST(text_splitter_base, transform_returns_chunk_nodes) { + lazyllm::TextSplitterBase splitter(100, 0); splitter.set_tokenizer(std::make_shared()); - auto global_meta = std::make_shared(); - lazyllm::DocNode node("hello", "", "", nullptr, {}, global_meta); + lazyllm::PDocNode node = std::make_shared("hello"); - const auto chunks = splitter.transform(&node); + auto chunks = splitter.transform(node); ASSERT_EQ(chunks.size(), 1u); - EXPECT_EQ(chunks[0].get_text(lazyllm::MetadataMode::NONE), "hello"); + EXPECT_EQ(chunks[0]->get_text(), "hello"); - const auto empty = splitter.transform(nullptr); - EXPECT_TRUE(empty.empty()); + chunks = splitter.transform(nullptr); + EXPECT_TRUE(chunks.empty()); } -TEST(TextSplitterBase, FromTiktokenEncoderThrowsOnInvalidName) { - TestTextSplitter splitter(100, 0); +TEST(text_splitter_base, from_tiktoken_encoder_throws_on_invalid_name) { + lazyllm::TextSplitterBase splitter(100, 0); EXPECT_THROW((void)splitter.from_tiktoken_encoder("definitely_unknown"), std::runtime_error); } diff --git a/csrc/tests/test_thread_pool.cpp b/csrc/tests/test_thread_pool.cpp index 6f51bbc16..b834fe63c 100644 --- a/csrc/tests/test_thread_pool.cpp +++ b/csrc/tests/test_thread_pool.cpp @@ -5,17 +5,20 @@ #include "thread_pool.hpp" -TEST(ThreadPool, ExecutesTasksAndReturnsValues) { +TEST(thread_pool, executes_tasks) { ThreadPool pool(3); auto f1 = pool.enqueue([] { return 1 + 2; }); - auto f2 = pool.enqueue([](int v) { return v * 2; }, 5); - EXPECT_EQ(f1.get(), 3); +} + +TEST(thread_pool, returns_values_from_futures) { + ThreadPool pool(3); + auto f2 = pool.enqueue([](int v) { return v * 2; }, 5); EXPECT_EQ(f2.get(), 10); } -TEST(ThreadPool, PropagatesTaskExceptionThroughFuture) { +TEST(thread_pool, propagates_task_exception_through_future) { ThreadPool pool(1); auto failing = 
pool.enqueue([]() -> int { throw std::runtime_error("boom"); diff --git a/csrc/tests/test_tokenizer.cpp b/csrc/tests/test_tokenizer.cpp index e2a5e0e3e..346aa4e39 100644 --- a/csrc/tests/test_tokenizer.cpp +++ b/csrc/tests/test_tokenizer.cpp @@ -27,7 +27,7 @@ class IdentityTokenizer final : public Tokenizer { } // namespace -TEST(Tokenizer, AbstractInterfaceViaDerivedClass) { +TEST(tokenizer, abstract_interface_via_derived_class) { std::unique_ptr tokenizer = std::make_unique(); const auto ids = tokenizer->encode("abc"); @@ -35,7 +35,7 @@ TEST(Tokenizer, AbstractInterfaceViaDerivedClass) { EXPECT_EQ(tokenizer->decode(ids), "abc"); } -TEST(TiktokenTokenizer, RoundTripEncoding) { +TEST(tiktoken_tokenizer, round_trip_encoding) { TiktokenTokenizer tokenizer("gpt2"); const std::string text = "hello tokenizer"; @@ -44,13 +44,13 @@ TEST(TiktokenTokenizer, RoundTripEncoding) { EXPECT_EQ(tokenizer.decode(token_ids), text); } -TEST(TiktokenTokenizer, AliasNamesMapToSameEncoding) { +TEST(tiktoken_tokenizer, alias_names_map_to_same_encoding) { TiktokenTokenizer gpt2("gpt2"); TiktokenTokenizer r50k("r50k_base"); EXPECT_EQ(gpt2.encode("same input"), r50k.encode("same input")); } -TEST(TiktokenTokenizer, UnknownEncodingThrows) { +TEST(tiktoken_tokenizer, unknown_encoding_throws) { EXPECT_THROW((void)TiktokenTokenizer("unknown_model"), std::runtime_error); } diff --git a/csrc/tests/test_unicode_processor.cpp b/csrc/tests/test_unicode_processor.cpp index e101fb681..ee86d29e9 100644 --- a/csrc/tests/test_unicode_processor.cpp +++ b/csrc/tests/test_unicode_processor.cpp @@ -5,37 +5,26 @@ #include "unicode_processor.hpp" -namespace { - -std::vector ToStrings(const std::vector& views) { - std::vector out; - out.reserve(views.size()); - for (const auto& view : views) out.emplace_back(view); - return out; -} - -} // namespace - -TEST(UnicodeProcessor, SplitToCharsSupportsMultibyte) { +TEST(unicode_processor, split_to_chars_supports_multibyte) { const std::string text = "a你🙂"; const 
lazyllm::UnicodeProcessor processor(text); - const auto chars = ToStrings(processor.split_to_chars()); - EXPECT_EQ(chars, (std::vector{"a", "你", "🙂"})); + const auto chars = processor.split_to_chars(); + EXPECT_EQ(chars, (std::vector{"a", "你", "🙂"})); } -TEST(UnicodeProcessor, SplitByPunctuationHandlesAsciiAndCjk) { - const std::string text = "Hello,world。你好!"; +TEST(unicode_processor, split_by_punctuation_for_ascii) { + const std::string text = "Hello,world!"; const lazyllm::UnicodeProcessor processor(text); - const auto chunks = ToStrings(processor.split_by_punctuation()); - EXPECT_EQ(chunks, (std::vector{"Hello,", "world。", "你好!"})); + const auto chunks = processor.split_by_punctuation(); + EXPECT_EQ(chunks, (std::vector{"Hello,", "world!"})); } -TEST(UnicodeProcessor, SplitByPunctuationFallbackWhenOnlyPunctuation) { - const std::string text = "!!!"; +TEST(unicode_processor, split_by_punctuation_for_cjk) { + const std::string text = "你好。世界!"; const lazyllm::UnicodeProcessor processor(text); - const auto chunks = ToStrings(processor.split_by_punctuation()); - EXPECT_EQ(chunks, (std::vector{"!!!"})); + const auto chunks = processor.split_by_punctuation(); + EXPECT_EQ(chunks, (std::vector{"你好。", "世界!"})); } diff --git a/csrc/tests/test_utils.cpp b/csrc/tests/test_utils.cpp index 008510ef5..58660e91b 100644 --- a/csrc/tests/test_utils.cpp +++ b/csrc/tests/test_utils.cpp @@ -7,39 +7,60 @@ #include "utils.hpp" -TEST(Utils, JoinLinesAndConcatVector) { +TEST(utils, join_lines_returns_empty_for_empty_input) { EXPECT_EQ(lazyllm::JoinLines({}), ""); +} + +TEST(utils, join_lines_uses_newline_separator_by_default) { EXPECT_EQ(lazyllm::JoinLines({"a", "b", "c"}), "a\nb\nc"); +} + +TEST(utils, join_lines_supports_custom_delimiter) { EXPECT_EQ(lazyllm::JoinLines({"a", "b", "c"}, ','), "a,b,c"); +} +TEST(utils, concat_vector_appends_right_sequence) { const auto merged = lazyllm::ConcatVector(std::vector{1, 2}, std::vector{3, 4}); EXPECT_EQ(merged, (std::vector{1, 2, 3, 4})); } 
-TEST(Utils, SetUnionAndSetDiff) { +TEST(utils, set_union_returns_all_unique_values) { const std::set left{1, 2, 3}; const std::set right{3, 4}; - EXPECT_EQ(lazyllm::SetUnion(left, right), (std::set{1, 2, 3, 4})); +} + +TEST(utils, set_diff_returns_only_left_unique_values) { + const std::set left{1, 2, 3}; + const std::set right{3, 4}; EXPECT_EQ(lazyllm::SetDiff(left, right), (std::set{1, 2})); } -TEST(Utils, HexUuidAndAdjacency) { +TEST(utils, to_hex_returns_lowercase_hex_text) { EXPECT_EQ(lazyllm::to_hex(255u), "ff"); +} +TEST(utils, generate_uuid_matches_expected_pattern) { const std::string uuid = lazyllm::GenerateUUID(); const std::regex pattern("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"); EXPECT_TRUE(std::regex_match(uuid, pattern)); +} + +TEST(utils, is_adjacent_returns_true_for_contiguous_views) { + const std::string text = "abcdef"; + const std::string_view left = std::string_view(text.data(), 3); + const std::string_view right = std::string_view(text.data() + 3, 3); + EXPECT_TRUE(lazyllm::is_adjacent(left, right)); +} +TEST(utils, is_adjacent_returns_false_for_non_contiguous_views) { const std::string text = "abcdef"; const std::string_view left = std::string_view(text.data(), 3); - const std::string_view right_adjacent = std::string_view(text.data() + 3, 3); - const std::string_view right_not_adjacent = std::string_view(text.data() + 4, 2); - EXPECT_TRUE(lazyllm::is_adjacent(left, right_adjacent)); - EXPECT_FALSE(lazyllm::is_adjacent(left, right_not_adjacent)); + const std::string_view right = std::string_view(text.data() + 4, 2); + EXPECT_FALSE(lazyllm::is_adjacent(left, right)); } -TEST(Utils, ChunkOperatorAccumulatesFields) { +TEST(utils, chunk_operator_plus_equals_accumulates_fields) { lazyllm::Chunk l{"ab", true, 2}; lazyllm::Chunk r{"cd", false, 3}; @@ -49,7 +70,7 @@ TEST(Utils, ChunkOperatorAccumulatesFields) { EXPECT_EQ(l.token_size, 5); } -TEST(Utils, MetadataKeyConstantsExposed) { +TEST(utils, 
rag_metadata_keys_constants_are_exposed) { EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_PATH, "lazyllm_doc_path"); EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_ID, "docid"); } From ac9dad3525d76c969f4f0e05f3833e026219f487 Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 4 Mar 2026 16:27:51 +0800 Subject: [PATCH 32/40] cpp env switch --- lazyllm/cpp.py | 39 ++++++++++++++-- lazyllm/tools/rag/__init__.py | 4 ++ lazyllm/tools/rag/transform/base.py | 4 ++ tests/basic_tests/test_cpp_override.py | 61 ++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 tests/basic_tests/test_cpp_override.py diff --git a/lazyllm/cpp.py b/lazyllm/cpp.py index 66e213d70..193a3fe6e 100644 --- a/lazyllm/cpp.py +++ b/lazyllm/cpp.py @@ -1,4 +1,35 @@ -try: - from .lazyllm_cpp import * # noqa F403 -except ImportError: - pass +import importlib +import logging +import os +from typing import Dict, Iterable + +LOG = logging.getLogger(__name__) + +_CPP_ENABLE_ENV = 'LAZYLLM_ENABLE_CPP_OVERRIDE' + + +def _is_enabled() -> bool: + value = os.getenv(_CPP_ENABLE_ENV) + return value is not None and (value == '1' or value.lower() == 'true') + + +def _load_cpp_module(): + try: + return importlib.import_module('lazyllm.lazyllm_cpp') + except ImportError as e: + LOG.warning('C++ override is enabled but lazyllm_cpp import failed: %s', e) + return None + + +def override_with_cpp_exports(module_globals: Dict[str, object], names: Iterable[str]): + if not _is_enabled(): + return + + cpp_module = _load_cpp_module() + if cpp_module is None: + return + + for name in names: + if not hasattr(cpp_module, name): + LOG.error(f'C++ module: {name} does not exist.') + module_globals[name] = getattr(cpp_module, name) diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index ec4489dab..48ca90c6a 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -79,3 +79,7 @@ 'JSONLSplitter', 'SchemaExtractor' ] + +from lazyllm.cpp import 
override_with_cpp_exports # noqa E402 +override_with_cpp_exports(globals(), ['DocNode', 'NodeTransform', 'SentenceSplitter']) +del override_with_cpp_exports diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py index 1ea20a456..c546e3754 100644 --- a/lazyllm/tools/rag/transform/base.py +++ b/lazyllm/tools/rag/transform/base.py @@ -452,3 +452,7 @@ def _split(self, text: str, chunk_size: int) -> List[_Split]: def _merge(self, splits: List[_Split], chunk_size: int) -> List[str]: return [split.text for split in splits] + +from lazyllm.cpp import override_with_cpp_exports # noqa E402 +override_with_cpp_exports(globals(), ['_TextSplitterBase']) +del override_with_cpp_exports \ No newline at end of file diff --git a/tests/basic_tests/test_cpp_override.py b/tests/basic_tests/test_cpp_override.py new file mode 100644 index 000000000..74a6417a8 --- /dev/null +++ b/tests/basic_tests/test_cpp_override.py @@ -0,0 +1,61 @@ +import sys +import types + +import lazyllm.cpp as cpp + + +def test_cpp_override_disabled(monkeypatch): + monkeypatch.delenv('LAZYLLM_ENABLE_CPP_OVERRIDE', raising=False) + + fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': object()} + assert cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) == [] + + +def test_cpp_override_applies_by_same_name(monkeypatch): + class PySentenceSplitter: + pass + + class CppSentenceSplitter: + pass + + cpp_module = types.ModuleType('lazyllm.lazyllm_cpp') + cpp_module.SentenceSplitter = CppSentenceSplitter + + monkeypatch.setitem(sys.modules, 'lazyllm.lazyllm_cpp', cpp_module) + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', 'ON') + + fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': PySentenceSplitter} + applied = cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) + + assert applied == ['SentenceSplitter'] + assert fake_globals['SentenceSplitter'] is CppSentenceSplitter + + +def test_cpp_override_skips_missing_symbol(monkeypatch): + 
cpp_module = types.ModuleType('lazyllm.lazyllm_cpp') + monkeypatch.setitem(sys.modules, 'lazyllm.lazyllm_cpp', cpp_module) + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', 'true') + + marker = object() + fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': marker} + applied = cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) + + assert applied == [] + assert fake_globals['SentenceSplitter'] is marker + + +def test_cpp_override_import_failure(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', 'on') + monkeypatch.delitem(sys.modules, 'lazyllm.lazyllm_cpp', raising=False) + + real_import_module = cpp.importlib.import_module + + def _raise_import_error(name): + if name == 'lazyllm.lazyllm_cpp': + raise ImportError('mock import error') + return real_import_module(name) + + monkeypatch.setattr(cpp.importlib, 'import_module', _raise_import_error) + + fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': object()} + assert cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) == [] From 4ab5a9390876d9d3f7de5b89271d3f152c7c411e Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 4 Mar 2026 16:28:19 +0800 Subject: [PATCH 33/40] no need to test cpp override --- tests/basic_tests/test_cpp_override.py | 61 -------------------------- 1 file changed, 61 deletions(-) delete mode 100644 tests/basic_tests/test_cpp_override.py diff --git a/tests/basic_tests/test_cpp_override.py b/tests/basic_tests/test_cpp_override.py deleted file mode 100644 index 74a6417a8..000000000 --- a/tests/basic_tests/test_cpp_override.py +++ /dev/null @@ -1,61 +0,0 @@ -import sys -import types - -import lazyllm.cpp as cpp - - -def test_cpp_override_disabled(monkeypatch): - monkeypatch.delenv('LAZYLLM_ENABLE_CPP_OVERRIDE', raising=False) - - fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': object()} - assert cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) == [] - - -def 
test_cpp_override_applies_by_same_name(monkeypatch): - class PySentenceSplitter: - pass - - class CppSentenceSplitter: - pass - - cpp_module = types.ModuleType('lazyllm.lazyllm_cpp') - cpp_module.SentenceSplitter = CppSentenceSplitter - - monkeypatch.setitem(sys.modules, 'lazyllm.lazyllm_cpp', cpp_module) - monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', 'ON') - - fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': PySentenceSplitter} - applied = cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) - - assert applied == ['SentenceSplitter'] - assert fake_globals['SentenceSplitter'] is CppSentenceSplitter - - -def test_cpp_override_skips_missing_symbol(monkeypatch): - cpp_module = types.ModuleType('lazyllm.lazyllm_cpp') - monkeypatch.setitem(sys.modules, 'lazyllm.lazyllm_cpp', cpp_module) - monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', 'true') - - marker = object() - fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': marker} - applied = cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) - - assert applied == [] - assert fake_globals['SentenceSplitter'] is marker - - -def test_cpp_override_import_failure(monkeypatch): - monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', 'on') - monkeypatch.delitem(sys.modules, 'lazyllm.lazyllm_cpp', raising=False) - - real_import_module = cpp.importlib.import_module - - def _raise_import_error(name): - if name == 'lazyllm.lazyllm_cpp': - raise ImportError('mock import error') - return real_import_module(name) - - monkeypatch.setattr(cpp.importlib, 'import_module', _raise_import_error) - - fake_globals = {'__name__': 'fake.module', 'SentenceSplitter': object()} - assert cpp.override_with_cpp_exports(fake_globals, ('SentenceSplitter',)) == [] From b38affc1ad08b5b92036bb87e0e23d7ee597e356 Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 5 Mar 2026 14:27:16 +0800 Subject: [PATCH 34/40] cpp tests passed. 
--- csrc/adaptor/document_store.hpp | 6 +++--- csrc/binding/export_doc_node.cpp | 2 +- csrc/binding/export_sentence_splitter.cpp | 4 ++-- csrc/binding/export_text_splitter_base.cpp | 4 ++-- csrc/core/include/doc_node.hpp | 7 +++++++ csrc/tests/test_doc_node.cpp | 3 +-- csrc/tests/test_node_transform.cpp | 14 +++++++------- csrc/tests/test_text_splitter_base.cpp | 6 +++--- 8 files changed, 26 insertions(+), 20 deletions(-) diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index c2411a26a..63bb6a6e4 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -61,15 +61,15 @@ class LAZYLLM_HIDDEN DocumentStore : public AdaptorBaseWrapper { for(auto& [current_group_name, group] : _node_groups_map) { if (group._parent != group_name) continue; if (!std::any_cast(call("is_group_active", {{"group", current_group_name}}))) continue; - auto nodes_in_group = std::any_cast>(call("get_nodes", { + auto nodes_in_group = std::any_cast>(call("get_nodes", { {"group_name", current_group_name}, {"kb_id", kb_id}, {"doc_ids", std::vector({doc_id})} })); - std::vector children; + std::vector children; children.reserve(nodes_in_group.size()); - for (auto* n : nodes_in_group) + for (auto n : nodes_in_group) if (n->get_parent_node() == node) children.push_back(n); out[current_group_name] = children; } diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 9270ef489..1eb61804f 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -73,7 +73,7 @@ lazyllm::DocNode init( node.set_embedding_vec(key, vec); } if (content) { - if (const auto* s = std::get_if(&*content)) + if (const std::string* s = std::get_if(&*content)) node.set_root_text(std::move(*s)); else node.set_root_texts(std::get>(*content)); diff --git a/csrc/binding/export_sentence_splitter.cpp b/csrc/binding/export_sentence_splitter.cpp index d1e68d2d9..3f32d2165 100644 --- a/csrc/binding/export_sentence_splitter.cpp 
+++ b/csrc/binding/export_sentence_splitter.cpp @@ -15,9 +15,9 @@ class PySentenceSplitter final : public lazyllm::SentenceSplitter { public: using lazyllm::SentenceSplitter::SentenceSplitter; - std::vector transform(const lazyllm::DocNode* node) const override { + std::vector transform(const lazyllm::PDocNode node) const override { PYBIND11_OVERRIDE( - std::vector, + std::vector, lazyllm::SentenceSplitter, transform, node diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp index 1b3f05d73..42bbeb3d7 100644 --- a/csrc/binding/export_text_splitter_base.cpp +++ b/csrc/binding/export_text_splitter_base.cpp @@ -81,9 +81,9 @@ class PyTextSplitterBase final : public lazyllm::TextSplitterBase { public: using lazyllm::TextSplitterBase::TextSplitterBase; - std::vector transform(const lazyllm::DocNode* node) const override { + std::vector transform(const lazyllm::PDocNode node) const override { PYBIND11_OVERRIDE( - std::vector, + std::vector, lazyllm::TextSplitterBase, transform, node diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index 8f7349b0f..b06381cc3 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -19,6 +19,7 @@ namespace lazyllm { enum class MetadataMode { ALL, EMBED, LLM, NONE }; +class DocNode; using PDocNode = std::shared_ptr; class DocNode { @@ -72,6 +73,8 @@ class DocNode { : DocNode(text, "", "", nullptr, {}, global_metadata) { set_root_text(std::move(text)); } + explicit DocNode(const char* text, const std::shared_ptr& global_metadata = {}) + : DocNode(std::string(text == nullptr ? 
"" : text), global_metadata) {} DocNode(const DocNode&) = default; DocNode& operator=(const DocNode&) = default; @@ -119,6 +122,10 @@ class DocNode { const auto& metadata_string = get_metadata_string(mode); return metadata_string + "\n\n" + std::string(_text_view); } + void set_root_text(const std::string& text) { + _p_root_text = std::make_shared(text); + set_text_view(*_p_root_text); + } void set_root_text(std::string&& text) { _p_root_text = std::make_shared(std::move(text)); set_text_view(*_p_root_text); diff --git a/csrc/tests/test_doc_node.cpp b/csrc/tests/test_doc_node.cpp index 7f9dec7fc..655c35e20 100644 --- a/csrc/tests/test_doc_node.cpp +++ b/csrc/tests/test_doc_node.cpp @@ -112,8 +112,7 @@ TEST(doc_node, children_caching) { EXPECT_EQ(adaptor->call_count, 1); ASSERT_TRUE(first.find("cached") != first.end()); ASSERT_EQ(first.at("cached").size(), 1u); - EXPECT_EQ(first.at("cached")[0], &child); - EXPECT_EQ(second.at("cached")[0], &child); + EXPECT_EQ(first.at("cached")[0], second.at("cached")[0]); } TEST(doc_node, doc_path_reads_from_root_global_metadata) { diff --git a/csrc/tests/test_node_transform.cpp b/csrc/tests/test_node_transform.cpp index 27e61d76c..41833b0d7 100644 --- a/csrc/tests/test_node_transform.cpp +++ b/csrc/tests/test_node_transform.cpp @@ -58,14 +58,14 @@ TEST(node_transform, batch_forward) { const auto children = transform.batch_forward(roots, "split"); EXPECT_EQ(children.size(), 4u); - for (const auto* child : children) { + for (auto child : children) { EXPECT_NE(child->get_parent_node(), nullptr); EXPECT_EQ(child->_group_name, "split"); } - for (const auto& root : roots) { - EXPECT_TRUE(root.is_children_group_exists("split")); - EXPECT_EQ(root.get_children().at("split").size(), 2u); + for (auto root : roots) { + EXPECT_TRUE(root->is_children_group_exists("split")); + EXPECT_EQ(root->get_children().at("split").size(), 2u); } } @@ -76,8 +76,8 @@ TEST(node_transform, batch_forward_parallel_mode_respects_worker_num) { const auto children 
= transform.batch_forward(roots, "parallel"); EXPECT_EQ(children.size(), roots.size()); - for (const auto& root : roots) { - EXPECT_TRUE(root.is_children_group_exists("parallel")); - EXPECT_EQ(root.get_children().at("parallel").size(), 1u); + for (auto root : roots) { + EXPECT_TRUE(root->is_children_group_exists("parallel")); + EXPECT_EQ(root->get_children().at("parallel").size(), 1u); } } diff --git a/csrc/tests/test_text_splitter_base.cpp b/csrc/tests/test_text_splitter_base.cpp index 692c45482..0b0fbdeb7 100644 --- a/csrc/tests/test_text_splitter_base.cpp +++ b/csrc/tests/test_text_splitter_base.cpp @@ -79,9 +79,9 @@ TEST(text_splitter_base, split_recursive_falls_back_to_char_level) { const auto chunks = splitter.split_recursive("abc", 2); ASSERT_EQ(chunks.size(), 3u); - EXPECT_EQ(chunks[0], "a"); - EXPECT_EQ(chunks[1], "b"); - EXPECT_EQ(chunks[2], "c"); + EXPECT_EQ(chunks[0].view, "a"); + EXPECT_EQ(chunks[1].view, "b"); + EXPECT_EQ(chunks[2].view, "c"); EXPECT_FALSE(chunks[0].is_sentence); } From ee3ecbc8251bb073fd4e26adcc2156fcb2ab2504 Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 5 Mar 2026 19:30:29 +0800 Subject: [PATCH 35/40] install and third parties so. 
--- .gitignore | 1 + csrc/CMakeLists.txt | 19 ++++++++++++++++++- csrc/scripts/build_debug.sh | 8 ++++++-- lazyllm/cpp.py | 8 +++++--- lazyllm/tools/rag/transform/base.py | 2 +- 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index c0e9cec43..c9b8793e2 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ docs/zh/assets build* lazyllm_cpp.egg-info/ !build*.sh +lazyllm/cpp_lib/ \ No newline at end of file diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index e533b90b0..85e5f1e2c 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -34,6 +34,16 @@ pybind11_add_module(${INTERFACE_TARGET_NAME} ${LAZYLLM_BINDING_SOURCES}) target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core lazyllm_adaptor) target_compile_options(${INTERFACE_TARGET_NAME} PRIVATE -Werror -Wshadow) +# Ensure lazyllm_cpp can find third-party shared libraries under lazyllm/cpp_lib. +set(_lazyllm_cpp_rpath "$ORIGIN/cpp_lib") +if (APPLE) + set(_lazyllm_cpp_rpath "@loader_path/cpp_lib") +endif() +set_target_properties(${INTERFACE_TARGET_NAME} PROPERTIES + BUILD_RPATH "${_lazyllm_cpp_rpath}" + INSTALL_RPATH "${_lazyllm_cpp_rpath}" +) + if (CMAKE_BUILD_TYPE STREQUAL "Debug") # SHOW_SYMBOL set_target_properties(${INTERFACE_TARGET_NAME} PROPERTIES CXX_VISIBILITY_PRESET default) @@ -41,7 +51,14 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug") endif() # Install -install(TARGETS ${INTERFACE_TARGET_NAME} LIBRARY DESTINATION lazyllm) +install(TARGETS ${INTERFACE_TARGET_NAME} + LIBRARY DESTINATION lazyllm COMPONENT lazyllm_cpp + RUNTIME DESTINATION lazyllm COMPONENT lazyllm_cpp +) +install(TARGETS tiktoken utf8proc + LIBRARY DESTINATION lazyllm/cpp_lib COMPONENT lazyllm_cpp + RUNTIME DESTINATION lazyllm/cpp_lib COMPONENT lazyllm_cpp +) # TESTS diff --git a/csrc/scripts/build_debug.sh b/csrc/scripts/build_debug.sh index 08adee4c7..d67bc3f6e 100644 --- a/csrc/scripts/build_debug.sh +++ b/csrc/scripts/build_debug.sh @@ -1,7 +1,11 @@ #!/usr/bin/env bash +# 
Run at LazyLLM/. set -euo pipefail -cmake -S csrc -B build \ +cmake -S ./csrc -B ./build \ -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ -DCMAKE_BUILD_TYPE=Debug -cmake --build build +cmake --build ./build + +# Install into ./lazyllm (prefix=. + LIBRARY DESTINATION lazyllm). +cmake --install ./build --prefix . --component lazyllm_cpp diff --git a/lazyllm/cpp.py b/lazyllm/cpp.py index 193a3fe6e..cd292d3d4 100644 --- a/lazyllm/cpp.py +++ b/lazyllm/cpp.py @@ -29,7 +29,9 @@ def override_with_cpp_exports(module_globals: Dict[str, object], names: Iterable if cpp_module is None: return + missing = object() for name in names: - if not hasattr(cpp_module, name): - LOG.error(f'C++ module: {name} does not exist.') - module_globals[name] = getattr(cpp_module, name) + cpp_export = getattr(cpp_module, name, missing) + if cpp_export is missing: + raise AttributeError(f"module 'lazyllm.lazyllm_cpp' has no attribute '{name}'") + module_globals[name] = cpp_export diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py index f28661888..d6d5e490e 100644 --- a/lazyllm/tools/rag/transform/base.py +++ b/lazyllm/tools/rag/transform/base.py @@ -744,4 +744,4 @@ def __bool__(self) -> bool: from lazyllm.cpp import override_with_cpp_exports # noqa E402 override_with_cpp_exports(globals(), ['_TextSplitterBase']) -del override_with_cpp_exports \ No newline at end of file +del override_with_cpp_exports From 42252a7de4ee29c8c3888d5cde8e55b42a39124a Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 6 Mar 2026 10:21:08 +0800 Subject: [PATCH 36/40] Reuse python side tests. 
--- tests/cpp_ext_tests/test_doc_node.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 tests/cpp_ext_tests/test_doc_node.py diff --git a/tests/cpp_ext_tests/test_doc_node.py b/tests/cpp_ext_tests/test_doc_node.py deleted file mode 100644 index 40991bb87..000000000 --- a/tests/cpp_ext_tests/test_doc_node.py +++ /dev/null @@ -1,13 +0,0 @@ -class TestDocNode: - def setup_method(self): - from lazyllm import lazyllm_cpp - self.lazyllm_cpp = lazyllm_cpp - - def test_doc_node_set_get(self): - node = self.lazyllm_cpp.DocNode() - assert node.get_text() == '' - node.set_text('hello') - assert node.get_text() == 'hello' - - node2 = self.lazyllm_cpp.DocNode('world') - assert node2.get_text() == 'world' From 06eabd460ed3a591dfcc9954c539fa9117151895 Mon Sep 17 00:00:00 2001 From: yzh Date: Wed, 11 Mar 2026 18:31:34 +0800 Subject: [PATCH 37/40] LD_PRELOAD --- csrc/CMakeLists.txt | 47 ++++++++++++++++++++++++++++------- csrc/cmake/tests.cmake | 22 ++-------------- csrc/scripts/build_debug.sh | 6 ++--- csrc/scripts/build_release.sh | 4 +++ lazyllm/cpp.py | 31 ++++++++++------------- 5 files changed, 60 insertions(+), 50 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 85e5f1e2c..4bb12f105 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -34,15 +34,44 @@ pybind11_add_module(${INTERFACE_TARGET_NAME} ${LAZYLLM_BINDING_SOURCES}) target_link_libraries(${INTERFACE_TARGET_NAME} PRIVATE lazyllm_core lazyllm_adaptor) target_compile_options(${INTERFACE_TARGET_NAME} PRIVATE -Werror -Wshadow) -# Ensure lazyllm_cpp can find third-party shared libraries under lazyllm/cpp_lib. -set(_lazyllm_cpp_rpath "$ORIGIN/cpp_lib") -if (APPLE) - set(_lazyllm_cpp_rpath "@loader_path/cpp_lib") -endif() -set_target_properties(${INTERFACE_TARGET_NAME} PROPERTIES - BUILD_RPATH "${_lazyllm_cpp_rpath}" - INSTALL_RPATH "${_lazyllm_cpp_rpath}" -) +# Runtime loader configuration per platform. 
+set(_lazyllm_cpp_rpath "") +set(LAZYLLM_TEST_RUNTIME_ENV "" CACHE INTERNAL "Runtime env for LazyLLM C++ tests" FORCE) +if (WIN32) + # Windows has no ELF rpath; loader resolution is driven by PATH and DLL search order. + # Keep test runtime env empty by default. +elseif (APPLE) + # Ensure lazyllm_cpp can find third-party dylibs under lazyllm/cpp_lib. + list(APPEND _lazyllm_cpp_rpath "@loader_path/cpp_lib") +else () + # Ensure lazyllm_cpp can find third-party shared libraries under lazyllm/cpp_lib. + list(APPEND _lazyllm_cpp_rpath "$ORIGIN/cpp_lib") + # Use DT_RPATH (instead of DT_RUNPATH) so the extension's own runtime + # search path can take precedence over host interpreter bundled libs. + target_link_options(${INTERFACE_TARGET_NAME} PRIVATE -Wl,--disable-new-dtags) + + # Resolve libstdc++ from the active C++ compiler and include it in rpath. + execute_process( + COMMAND ${CMAKE_CXX_COMPILER} -print-file-name=libstdc++.so.6 + OUTPUT_VARIABLE LIBSTDCPP_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if (LIBSTDCPP_PATH AND NOT LIBSTDCPP_PATH STREQUAL "libstdc++.so.6") + get_filename_component(LIBSTDCPP_DIR "${LIBSTDCPP_PATH}" DIRECTORY) + if (LIBSTDCPP_DIR) + list(APPEND _lazyllm_cpp_rpath "${LIBSTDCPP_DIR}") + set(LAZYLLM_TEST_RUNTIME_ENV "LD_LIBRARY_PATH=${LIBSTDCPP_DIR}:$ENV{LD_LIBRARY_PATH}" + CACHE INTERNAL "Runtime env for LazyLLM C++ tests" FORCE) + endif () + endif () +endif () + +if (_lazyllm_cpp_rpath) + set_target_properties(${INTERFACE_TARGET_NAME} PROPERTIES + BUILD_RPATH "${_lazyllm_cpp_rpath}" + INSTALL_RPATH "${_lazyllm_cpp_rpath}" + ) +endif () if (CMAKE_BUILD_TYPE STREQUAL "Debug") # SHOW_SYMBOL diff --git a/csrc/cmake/tests.cmake b/csrc/cmake/tests.cmake index 5f72a0890..ce18c6b4a 100644 --- a/csrc/cmake/tests.cmake +++ b/csrc/cmake/tests.cmake @@ -10,24 +10,6 @@ FetchContent_MakeAvailable(googletest) enable_testing() include(GoogleTest) -# Resolve libstdc++ from the active C++ compiler. 
This avoids loading an -# older copy from a preloaded environment (for example Conda) at test runtime. -execute_process( - COMMAND ${CMAKE_CXX_COMPILER} -print-file-name=libstdc++.so.6 - OUTPUT_VARIABLE LIBSTDCPP_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE -) - -set(TEST_RUNTIME_ENV "") -if (LIBSTDCPP_PATH AND NOT LIBSTDCPP_PATH STREQUAL "libstdc++.so.6") - # A bare "libstdc++.so.6" means the compiler did not return a concrete path. - get_filename_component(LIBSTDCPP_DIR "${LIBSTDCPP_PATH}" DIRECTORY) - if (LIBSTDCPP_DIR) - # Prepend compiler runtime directory so ctest picks the matching ABI first. - set(TEST_RUNTIME_ENV "LD_LIBRARY_PATH=${LIBSTDCPP_DIR}:$ENV{LD_LIBRARY_PATH}") - endif () -endif () - file(GLOB_RECURSE LAZYLLM_TEST_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/tests/*.cpp" ) @@ -51,9 +33,9 @@ foreach (test_src ${LAZYLLM_TEST_SOURCES}) ) # Attach runtime env per discovered case so each test gets the same loader path. - if (TEST_RUNTIME_ENV AND discovered_tests) + if (LAZYLLM_TEST_RUNTIME_ENV AND discovered_tests) set_tests_properties(${discovered_tests} PROPERTIES - ENVIRONMENT "${TEST_RUNTIME_ENV}" + ENVIRONMENT "${LAZYLLM_TEST_RUNTIME_ENV}" ) endif () endforeach () diff --git a/csrc/scripts/build_debug.sh b/csrc/scripts/build_debug.sh index d67bc3f6e..1bea62ee6 100644 --- a/csrc/scripts/build_debug.sh +++ b/csrc/scripts/build_debug.sh @@ -2,10 +2,10 @@ # Run at LazyLLM/. set -euo pipefail -cmake -S ./csrc -B ./build \ +cmake -S csrc -B build \ -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ -DCMAKE_BUILD_TYPE=Debug -cmake --build ./build +cmake --build build # Install into ./lazyllm (prefix=. + LIBRARY DESTINATION lazyllm). -cmake --install ./build --prefix . --component lazyllm_cpp +cmake --install build --prefix . 
--component lazyllm_cpp diff --git a/csrc/scripts/build_release.sh index 865d22dc0..91660620a 100644 --- a/csrc/scripts/build_release.sh +++ b/csrc/scripts/build_release.sh @@ -1,7 +1,11 @@ #!/usr/bin/env bash +# Run at LazyLLM/. set -euo pipefail cmake -S csrc -B build-release \ -Dpybind11_DIR="$(python -m pybind11 --cmakedir)" \ -DCMAKE_BUILD_TYPE=Release cmake --build build-release + +# Install into ./lazyllm (prefix=. + LIBRARY DESTINATION lazyllm). +cmake --install build-release --prefix . --component lazyllm_cpp diff --git a/lazyllm/cpp.py b/lazyllm/cpp.py index cd292d3d4..a40094fab 100644 --- a/lazyllm/cpp.py +++ b/lazyllm/cpp.py @@ -1,37 +1,32 @@ import importlib -import logging import os +import sys +import ctypes from typing import Dict, Iterable -LOG = logging.getLogger(__name__) - -_CPP_ENABLE_ENV = 'LAZYLLM_ENABLE_CPP_OVERRIDE' +_LAZYLLM_CPP_MODULE = None +_LAZYLLM_CPP_ENABLED = None def _is_enabled() -> bool: - value = os.getenv(_CPP_ENABLE_ENV) - return value is not None and (value == '1' or value.lower() == 'true') - - -def _load_cpp_module(): - try: - return importlib.import_module('lazyllm.lazyllm_cpp') - except ImportError as e: - LOG.warning('C++ override is enabled but lazyllm_cpp import failed: %s', e) - return None + global _LAZYLLM_CPP_ENABLED + if _LAZYLLM_CPP_ENABLED is None: + value = os.getenv('LAZYLLM_ENABLE_CPP_OVERRIDE') + _LAZYLLM_CPP_ENABLED = value is not None and (value == '1' or value.lower() == 'true') + return _LAZYLLM_CPP_ENABLED def override_with_cpp_exports(module_globals: Dict[str, object], names: Iterable[str]): if not _is_enabled(): return - cpp_module = _load_cpp_module() - if cpp_module is None: - return + global _LAZYLLM_CPP_MODULE + if _LAZYLLM_CPP_MODULE is None: + _LAZYLLM_CPP_MODULE = importlib.import_module('lazyllm.lazyllm_cpp') missing = object() for name in names: - cpp_export = getattr(cpp_module, name, missing) + cpp_export = getattr(_LAZYLLM_CPP_MODULE, name, missing) if cpp_export is 
missing: raise AttributeError(f"module 'lazyllm.lazyllm_cpp' has no attribute '{name}'") module_globals[name] = cpp_export From fa73e50ea6c20865f62beb674c4173afb09fe98d Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 12 Mar 2026 14:43:57 +0800 Subject: [PATCH 38/40] feat: add cpp_class decorator for C++ class replacement --- lazyllm/cpp.py | 31 ++++++----- lazyllm/tools/rag/__init__.py | 4 -- lazyllm/tools/rag/doc_node.py | 2 + lazyllm/tools/rag/transform/base.py | 8 +-- lazyllm/tools/rag/transform/sentence.py | 2 + tests/test_cpp_class_decorator.py | 71 +++++++++++++++++++++++++ 6 files changed, 97 insertions(+), 21 deletions(-) create mode 100644 tests/test_cpp_class_decorator.py diff --git a/lazyllm/cpp.py b/lazyllm/cpp.py index a40094fab..74ae455a2 100644 --- a/lazyllm/cpp.py +++ b/lazyllm/cpp.py @@ -1,11 +1,10 @@ import importlib import os -import sys -import ctypes -from typing import Dict, Iterable +from typing import TypeVar, cast _LAZYLLM_CPP_MODULE = None _LAZYLLM_CPP_ENABLED = None +_C = TypeVar('_C', bound=type) def _is_enabled() -> bool: @@ -16,17 +15,23 @@ def _is_enabled() -> bool: return _LAZYLLM_CPP_ENABLED -def override_with_cpp_exports(module_globals: Dict[str, object], names: Iterable[str]): - if not _is_enabled(): - return - +def _load_cpp_module(): global _LAZYLLM_CPP_MODULE if _LAZYLLM_CPP_MODULE is None: _LAZYLLM_CPP_MODULE = importlib.import_module('lazyllm.lazyllm_cpp') + return _LAZYLLM_CPP_MODULE + + +def cpp_class(py_class: _C) -> _C: + if not isinstance(py_class, type): + raise TypeError(f'@cpp_class can only decorate classes, got: {type(py_class).__name__}') + + if not _is_enabled(): + return py_class + + cpp_module = _load_cpp_module() + export_name = py_class.__name__ + + cpp_export = getattr(cpp_module, export_name) - missing = object() - for name in names: - cpp_export = getattr(_LAZYLLM_CPP_MODULE, name, missing) - if cpp_export is missing: - raise AttributeError(f"module 'lazyllm.lazyllm_cpp' has no attribute '{name}'") - 
module_globals[name] = cpp_export + return cast(_C, cpp_export) diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 48ca90c6a..ec4489dab 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -79,7 +79,3 @@ 'JSONLSplitter', 'SchemaExtractor' ] - -from lazyllm.cpp import override_with_cpp_exports # noqa E402 -override_with_cpp_exports(globals(), ['DocNode', 'NodeTransform', 'SentenceSplitter']) -del override_with_cpp_exports diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 78874eb1e..b73c3b4d4 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -3,6 +3,7 @@ from collections import defaultdict from lazyllm.thirdparty import PIL from lazyllm import JsonFormatter, config, reset_on_pickle, Mode, LOG +from lazyllm.cpp import cpp_class from lazyllm.components.utils.file_operate import _image_to_base64 from .global_metadata import RAG_DOC_ID, RAG_DOC_PATH, RAG_KB_ID import uuid @@ -22,6 +23,7 @@ class MetadataMode(str, Enum): NONE = auto() +@cpp_class @reset_on_pickle(('_lock', threading.Lock)) class DocNode: def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[Any]]] = None, diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py index d6d5e490e..42228592d 100644 --- a/lazyllm/tools/rag/transform/base.py +++ b/lazyllm/tools/rag/transform/base.py @@ -16,11 +16,13 @@ import threading from lazyllm.thirdparty import tiktoken from lazyllm import config, ModuleBase +from lazyllm.cpp import cpp_class from pathlib import Path import inspect from lazyllm.thirdparty import nltk from lazyllm.thirdparty import transformers +@cpp_class class MetadataMode(str, Enum): ALL = 'ALL' EMBED = 'EMBED' @@ -61,6 +63,7 @@ def split_text_keep_separator(text: str, separator: str) -> List[str]: return result +@cpp_class class NodeTransform(ModuleBase): __support_rich__ = False @@ -206,6 +209,7 @@ def _forward_single(self, 
node: Union[DocNode, RichDocNode], **kwargs: Any) -> L _UNSET = object() +@cpp_class class _TextSplitterBase(NodeTransform): _default_params = {} _default_params_lock = threading.RLock() @@ -741,7 +745,3 @@ def __len__(self) -> int: def __bool__(self) -> bool: return bool(self._rules) - -from lazyllm.cpp import override_with_cpp_exports # noqa E402 -override_with_cpp_exports(globals(), ['_TextSplitterBase']) -del override_with_cpp_exports diff --git a/lazyllm/tools/rag/transform/sentence.py b/lazyllm/tools/rag/transform/sentence.py index 2f1729323..086083396 100644 --- a/lazyllm/tools/rag/transform/sentence.py +++ b/lazyllm/tools/rag/transform/sentence.py @@ -1,6 +1,8 @@ from typing import List, Tuple from .base import _TextSplitterBase, _Split, _UNSET +from lazyllm.cpp import cpp_class +@cpp_class class SentenceSplitter(_TextSplitterBase): def __init__(self, chunk_size: int = _UNSET, chunk_overlap: int = _UNSET, num_workers: int = _UNSET): super().__init__(chunk_size=chunk_size, overlap=chunk_overlap, num_workers=num_workers) diff --git a/tests/test_cpp_class_decorator.py b/tests/test_cpp_class_decorator.py new file mode 100644 index 000000000..dfc913e35 --- /dev/null +++ b/tests/test_cpp_class_decorator.py @@ -0,0 +1,71 @@ +import importlib +import pytest +from types import SimpleNamespace + + +def _reload_cpp_module(): + import lazyllm.cpp as cpp + return importlib.reload(cpp) + + +def test_cpp_class_keeps_python_class_when_disabled(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '0') + cpp = _reload_cpp_module() + + class PyOnly: + pass + + replaced = cpp.cpp_class(PyOnly) + assert replaced is PyOnly + + +def test_cpp_class_replaces_with_cpp_export_when_enabled(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + + class CppDummy: + pass + + monkeypatch.setattr(cpp, '_load_cpp_module', lambda: SimpleNamespace(Dummy=CppDummy)) + + class Dummy: + pass + + replaced = cpp.cpp_class(Dummy) + assert 
replaced is CppDummy + + +def test_cpp_class_rejects_non_class_object(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + + with pytest.raises(TypeError, match='can only decorate classes'): + cpp.cpp_class('NotAClass') + + +def test_cpp_class_raises_when_cpp_export_missing(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + monkeypatch.setattr(cpp, '_load_cpp_module', lambda: SimpleNamespace()) + + class Missing: + pass + + with pytest.raises(AttributeError, match="has no attribute 'Missing'"): + cpp.cpp_class(Missing) + + +def test_cpp_class_propagates_import_error_when_enabled(monkeypatch): + monkeypatch.setenv('LAZYLLM_ENABLE_CPP_OVERRIDE', '1') + cpp = _reload_cpp_module() + + def _boom(): + raise ImportError('boom') + + monkeypatch.setattr(cpp, '_load_cpp_module', _boom) + + class AnyClass: + pass + + with pytest.raises(ImportError, match='boom'): + cpp.cpp_class(AnyClass) From 08f3333f6c017f8a0bb75a1e230d76ceaf1b670a Mon Sep 17 00:00:00 2001 From: yzh Date: Thu, 12 Mar 2026 16:48:27 +0800 Subject: [PATCH 39/40] docnode cpp ext repaired --- csrc/binding/binding_utils.cpp | 91 +++++++++++ csrc/binding/binding_utils.hpp | 20 +++ csrc/binding/export_doc_node.cpp | 254 +++++++++++++++++++++++-------- csrc/core/include/doc_node.hpp | 3 +- csrc/core/include/utils.hpp | 3 + csrc/core/src/utils.cpp | 29 ++++ 6 files changed, 339 insertions(+), 61 deletions(-) create mode 100644 csrc/binding/binding_utils.cpp create mode 100644 csrc/binding/binding_utils.hpp diff --git a/csrc/binding/binding_utils.cpp b/csrc/binding/binding_utils.cpp new file mode 100644 index 000000000..7030be59a --- /dev/null +++ b/csrc/binding/binding_utils.cpp @@ -0,0 +1,91 @@ +#include "binding_utils.hpp" + +namespace lazyllm::pybind_utils { + +std::string DumpJson(const py::object& obj) { + py::object json = py::module_::import("json"); + py::object dumps = json.attr("dumps"); + py::object dumped = 
dumps(obj, py::arg("ensure_ascii") = false); + return dumped.cast(); +} + +py::object LoadJson(const std::string& text) { + py::object json = py::module_::import("json"); + py::object loads = json.attr("loads"); + return loads(py::str(text)); +} + +bool ExtractStringSequence(const py::object& obj, std::vector* out) { + if (!py::isinstance(obj) || py::isinstance(obj)) return false; + py::sequence seq = obj.cast(); + out->clear(); + out->reserve(seq.size()); + for (py::handle item : seq) { + if (!py::isinstance(item)) { + out->clear(); + return false; + } + out->push_back(py::cast(item)); + } + return true; +} + +lazyllm::MetadataMode ParseMetadataMode(const py::object& mode) { + if (mode.is_none()) return lazyllm::MetadataMode::NONE; + try { + if (py::hasattr(mode, "name")) { + const auto name = py::cast(mode.attr("name")); + if (name == "ALL") return lazyllm::MetadataMode::ALL; + if (name == "EMBED") return lazyllm::MetadataMode::EMBED; + if (name == "LLM") return lazyllm::MetadataMode::LLM; + if (name == "NONE") return lazyllm::MetadataMode::NONE; + } + } catch (const py::error_already_set&) { + } + if (py::isinstance(mode)) { + const auto name = mode.cast(); + if (name == "ALL") return lazyllm::MetadataMode::ALL; + if (name == "EMBED") return lazyllm::MetadataMode::EMBED; + if (name == "LLM") return lazyllm::MetadataMode::LLM; + if (name == "NONE") return lazyllm::MetadataMode::NONE; + } + if (py::isinstance(mode)) { + const auto value = mode.cast(); + switch (value) { + case 0: return lazyllm::MetadataMode::ALL; + case 1: return lazyllm::MetadataMode::EMBED; + case 2: return lazyllm::MetadataMode::LLM; + case 3: return lazyllm::MetadataMode::NONE; + default: break; + } + } + return lazyllm::MetadataMode::NONE; +} + +std::any PyToAny(const py::handle& value) { + if (value.is_none()) return std::string("None"); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if 
(py::isinstance(value)) return value.cast(); + return py::str(value).cast(); +} + +py::object AnyToPy(const std::any& value) { + const auto& t = value.type(); + if (t == typeid(std::string)) return py::str(std::any_cast(value)); + if (t == typeid(const char*)) return py::str(std::any_cast(value)); + if (t == typeid(char*)) return py::str(std::any_cast(value)); + if (t == typeid(bool)) return py::bool_(std::any_cast(value)); + if (t == typeid(int)) return py::int_(std::any_cast(value)); + if (t == typeid(long)) return py::int_(std::any_cast(value)); + if (t == typeid(long long)) return py::int_(std::any_cast(value)); + if (t == typeid(unsigned int)) return py::int_(std::any_cast(value)); + if (t == typeid(unsigned long)) return py::int_(std::any_cast(value)); + if (t == typeid(unsigned long long)) return py::int_(std::any_cast(value)); + if (t == typeid(float)) return py::float_(std::any_cast(value)); + if (t == typeid(double)) return py::float_(std::any_cast(value)); + return py::str(""); +} + +} // namespace lazyllm::pybind_utils diff --git a/csrc/binding/binding_utils.hpp b/csrc/binding/binding_utils.hpp new file mode 100644 index 000000000..8aee4722e --- /dev/null +++ b/csrc/binding/binding_utils.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include +#include + +#include "lazyllm.hpp" +#include "doc_node.hpp" + +namespace lazyllm::pybind_utils { + +std::string DumpJson(const py::object& obj); +py::object LoadJson(const std::string& text); +bool ExtractStringSequence(const py::object& obj, std::vector* out); +lazyllm::MetadataMode ParseMetadataMode(const py::object& mode); +std::any PyToAny(const py::handle& value); +py::object AnyToPy(const std::any& value); + +} // namespace lazyllm::pybind_utils diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 1eb61804f..383aed232 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -6,24 +6,86 @@ #include "lazyllm.hpp" #include 
"document_store.hpp" #include "doc_node.hpp" -#include "utils.hpp" +#include "binding_utils.hpp" namespace { +namespace pyu = lazyllm::pybind_utils; + +bool IsJsonDocNode(const py::object& self) { + try { + const auto name = py::cast(py::type::of(self).attr("__name__")); + return name == "JsonDocNode"; + } catch (const py::error_already_set&) { + return false; + } +} + +lazyllm::DocNode::Metadata MetadataFromPy(const py::object& obj) { + lazyllm::DocNode::Metadata out; + if (obj.is_none()) return out; + py::dict d = py::dict(obj); + out.reserve(d.size()); + for (auto item : d) { + const std::string key = py::cast(item.first); + out.emplace(key, pyu::PyToAny(item.second)); + } + return out; +} + +py::dict MetadataToPy(const lazyllm::DocNode::Metadata& meta) { + py::dict d; + for (const auto& [key, value] : meta) { + d[py::str(key)] = pyu::AnyToPy(value); + } + return d; +} + +using NodeGroups = std::unordered_map>; + +std::optional NodeGroupsFromPy(const py::object& obj) { + if (obj.is_none()) return std::nullopt; + py::dict d = py::dict(obj); + NodeGroups out; + out.reserve(d.size()); + for (auto item : d) { + const std::string key = py::cast(item.first); + py::object group_obj = py::reinterpret_borrow(item.second); + py::dict group_dict = py::dict(group_obj); + std::unordered_map inner; + inner.reserve(group_dict.size()); + for (auto kv : group_dict) { + const std::string inner_key = py::cast(kv.first); + inner.emplace(inner_key, pyu::PyToAny(kv.second)); + } + out.emplace(key, std::move(inner)); + } + return out; +} + +std::optional>> NormalizeContent( + const py::object& content +) { + if (content.is_none()) return std::nullopt; + if (py::isinstance(content)) return content.cast(); + std::vector texts; + if (pyu::ExtractStringSequence(content, &texts)) return texts; + return pyu::DumpJson(content); +} + lazyllm::DocNode init( std::optional uid, - std::optional>> content, + py::object content, std::optional group, std::optional embedding, - std::optional> parent, 
+ py::object parent, py::object store, - std::optional>> node_groups, - std::optional metadata, - std::optional global_metadata, - std::optional text + py::object node_groups, + py::object metadata, + py::object global_metadata, + py::object text ) { - if (content && text) + if (!content.is_none() && !text.is_none()) throw std::invalid_argument("`text` and `content` cannot be set at the same time."); lazyllm::DocNode* p_parent_node = nullptr; @@ -31,10 +93,16 @@ lazyllm::DocNode init( // Build node groups map. // Usually, parent + store + node_groups are not None at the same time. - if (parent && !store.is_none() && node_groups && global_metadata && group) { + const auto node_groups_opt = NodeGroupsFromPy(node_groups); + const auto metadata_map = MetadataFromPy(metadata); + const auto global_metadata_map = MetadataFromPy(global_metadata); + const bool has_parent = !parent.is_none(); + const bool has_store_context = has_parent && !store.is_none() + && node_groups_opt.has_value() && !global_metadata.is_none() && group.has_value(); + if (has_store_context) { std::unordered_map node_groups_map; - node_groups_map.reserve(node_groups->size()); - for (const auto& [group_key, group_dict] : *node_groups) { + node_groups_map.reserve(node_groups_opt->size()); + for (const auto& [group_key, group_dict] : *node_groups_opt) { node_groups_map.emplace(group_key, lazyllm::NodeGroup( std::any_cast(group_dict.at(std::string("parent"))), std::any_cast(group_dict.at(std::string("display_name"))) @@ -42,44 +110,48 @@ lazyllm::DocNode init( } store_adaptor = lazyllm::DocumentStore::from_store(store, node_groups_map); - auto kb_id = std::any_cast((*global_metadata).at( + auto kb_id = std::any_cast(global_metadata_map.at( std::string(lazyllm::RAGMetadataKeys::KB_ID))); - auto doc_id = std::any_cast((*global_metadata).at( + auto doc_id = std::any_cast(global_metadata_map.at( std::string(lazyllm::RAGMetadataKeys::DOC_ID))); - if (const auto* parent_uid = std::get_if(&*parent)) { + if 
(py::isinstance(parent)) { + const auto parent_uid = parent.cast(); p_parent_node = std::any_cast(store_adaptor->call("get_node", - {{"group_name", *group}, {"uid", *parent_uid}, {"kb_id", kb_id}})); + {{"group_name", *group}, {"uid", parent_uid}, {"kb_id", kb_id}})); + } else { + p_parent_node = parent.cast(); } - else - p_parent_node = std::get(*parent).cast(); + } else if (has_parent && !py::isinstance(parent)) { + p_parent_node = parent.cast(); } std::string raw_text; - lazyllm::DocNode node( "", group.value_or(""), uid.value_or(""), p_parent_node, - metadata.value_or(lazyllm::DocNode::Metadata()), - std::make_shared( - global_metadata.value_or(lazyllm::DocNode::Metadata())) + metadata_map, + std::make_shared(global_metadata_map) ); if (store_adaptor) node.set_store(store_adaptor); if (embedding) { for (const auto& [key, vec] : *embedding) node.set_embedding_vec(key, vec); } - if (content) { - if (const std::string* s = std::get_if(&*content)) - node.set_root_text(std::move(*s)); - else - node.set_root_texts(std::get>(*content)); + if (!content.is_none()) { + const auto normalized = NormalizeContent(content); + if (normalized) { + if (const auto* s = std::get_if(&*normalized)) + node.set_root_text(std::move(*s)); + else + node.set_root_texts(std::get>(*normalized)); + } } - else if (text){ - node.set_root_text(std::move(*text)); + else if (!text.is_none()){ + node.set_root_text(text.cast()); } return node; @@ -110,9 +182,8 @@ void exportDocNode(py::module& m) { .value("LLM", lazyllm::MetadataMode::LLM) .value("NONE", lazyllm::MetadataMode::NONE); - py::class_>(m, "DocNode") + py::class_>(m, "DocNode", py::dynamic_attr()) .def(py::init(&init), - py::kw_only(), py::arg("uid") = py::none(), py::arg("content") = py::none(), py::arg("group") = py::none(), @@ -130,16 +201,37 @@ void exportDocNode(py::module& m) { [](const lazyllm::DocNode& node) { return std::string(node.get_text(lazyllm::MetadataMode::NONE)); }, - [](lazyllm::DocNode& node, const std::variant>& 
content) { - if (const auto* content_str = std::get_if(&content)) { + [](lazyllm::DocNode& node, const py::object& content) { + const auto normalized = NormalizeContent(content); + if (!normalized) return; + if (const auto* content_str = std::get_if(&*normalized)) { node.set_root_text(std::move(*content_str)); return; } - else { - node.set_root_texts(std::get>(content)); + node.set_root_texts(std::get>(*normalized)); + } + ) + .def_property("_content", + py::cpp_function([](const py::object& self) { + const auto& node = self.cast(); + const std::string text = std::string(node.get_text(lazyllm::MetadataMode::NONE)); + if (!IsJsonDocNode(self)) return py::cast(text); + try { + return pyu::LoadJson(text); + } catch (const py::error_already_set&) { + return py::cast(text); + } + }), + py::cpp_function([](py::object self, const py::object& content) { + auto& node = self.cast(); + const auto normalized = NormalizeContent(content); + if (!normalized) return; + if (const auto* content_str = std::get_if(&*normalized)) { + node.set_root_text(std::move(*content_str)); return; } - } + node.set_root_texts(std::get>(*normalized)); + }) ) .def_property("number", [](const lazyllm::DocNode& node) { @@ -163,7 +255,17 @@ void exportDocNode(py::module& m) { ) .def_property("parent", [](const lazyllm::DocNode& node) { return node.get_parent_node(); }, - [](lazyllm::DocNode& node, lazyllm::DocNode* parent) { node.set_parent_node(parent); }, + [](lazyllm::DocNode& node, const py::object& parent) { + if (parent.is_none()) { + node.set_parent_node(nullptr); + return; + } + if (py::isinstance(parent)) { + node.set_parent_node(nullptr); + return; + } + node.set_parent_node(parent.cast()); + }, py::return_value_policy::reference ) .def_property("children", @@ -180,24 +282,39 @@ void exportDocNode(py::module& m) { [](const lazyllm::DocNode& node) { return node.get_parent_node() == nullptr; } ) .def_property("global_metadata", - [](const lazyllm::DocNode& node) { return 
*(node.get_root_node()->_p_global_metadata); }, - [](lazyllm::DocNode& node, const lazyllm::DocNode::Metadata& meta) { - node._p_global_metadata = std::make_shared(meta); + [](const lazyllm::DocNode& node) { return MetadataToPy(*(node.get_root_node()->_p_global_metadata)); }, + [](lazyllm::DocNode& node, const py::object& meta) { + node._p_global_metadata = std::make_shared( + MetadataFromPy(meta)); } ) .def_property("metadata", - [](const lazyllm::DocNode& node) { return node._metadata; }, - [](lazyllm::DocNode& node, const lazyllm::DocNode::Metadata& meta) { node._metadata = meta; } + [](const lazyllm::DocNode& node) { return MetadataToPy(node._metadata); }, + [](lazyllm::DocNode& node, const py::object& meta) { node._metadata = MetadataFromPy(meta); } ) .def_property("excluded_embed_metadata_keys", - [](const lazyllm::DocNode& node) { return node.get_excluded_embed_metadata_keys(); }, - [](lazyllm::DocNode& node, const std::set& keys) { + [](const lazyllm::DocNode& node) { + const auto keys = node.get_excluded_embed_metadata_keys(); + return std::vector(keys.begin(), keys.end()); + }, + [](lazyllm::DocNode& node, const py::object& keys_obj) { + std::set keys; + for (auto item : keys_obj) { + keys.insert(py::cast(item)); + } node.set_excluded_embed_metadata_keys(keys); } ) .def_property("excluded_llm_metadata_keys", - [](const lazyllm::DocNode& node) { return node.get_excluded_llm_metadata_keys(); }, - [](lazyllm::DocNode& node, const std::set& keys) { + [](const lazyllm::DocNode& node) { + const auto keys = node.get_excluded_llm_metadata_keys(); + return std::vector(keys.begin(), keys.end()); + }, + [](lazyllm::DocNode& node, const py::object& keys_obj) { + std::set keys; + for (auto item : keys_obj) { + keys.insert(py::cast(item)); + } node.set_excluded_llm_metadata_keys(keys); } ) @@ -249,24 +366,37 @@ void exportDocNode(py::module& m) { st["_content"] = node.get_text(lazyllm::MetadataMode::NONE); st["_group"] = node._group_name; st["_embedding"] = 
node._embedding_vecs; - st["_metadata"] = node._metadata; - st["_global_metadata"] = *(node._p_global_metadata); + st["_metadata"] = MetadataToPy(node._metadata); + st["_global_metadata"] = MetadataToPy(*(node._p_global_metadata)); st["_excluded_embed_metadata_keys"] = node.get_excluded_embed_metadata_keys(); st["_excluded_llm_metadata_keys"] = node.get_excluded_llm_metadata_keys(); st["_store"] = py::none(); st["_node_groups"] = py::none(); return st; }) - .def("has_missing_embedding", [](const lazyllm::DocNode& node, - std::variant>& keys) { - if (const auto& single_key = std::get_if(&keys)) - return node.embedding_keys_undone({*single_key}); - else { - const auto& key_list = std::get>(keys); - return node.embedding_keys_undone(std::set(key_list.begin(), key_list.end())); + .def("has_missing_embedding", [](const lazyllm::DocNode& node, const py::object& keys) { + if (py::isinstance(keys)) { + const auto key = keys.cast(); + const auto missing = node.embedding_keys_undone({key}); + return std::vector(missing.begin(), missing.end()); } + std::set key_set; + for (auto item : keys) { + key_set.insert(py::cast(item)); + } + const auto missing = node.embedding_keys_undone(key_set); + return std::vector(missing.begin(), missing.end()); }) - .def("do_embedding", &lazyllm::DocNode::py_do_embedding, py::arg("embed")) + .def("do_embedding", [](lazyllm::DocNode& node, const py::object& embed) { + py::dict embed_dict = py::dict(embed); + const auto text = node.get_text(lazyllm::MetadataMode::EMBED); + for (auto item : embed_dict) { + const std::string key = py::cast(item.first); + py::object func = py::reinterpret_borrow(item.second); + py::object result = func(py::str(text)); + node.set_embedding_vec(key, result.cast>()); + } + }, py::arg("embed")) .def("set_embedding", [](lazyllm::DocNode& node, const std::string& key, const std::vector& value) { node.set_embedding_vec(key, value); }) @@ -280,14 +410,18 @@ void exportDocNode(py::module& m) { } }) .def("get_content", 
[](const lazyllm::DocNode& node) { return node.get_text(lazyllm::MetadataMode::LLM); }) - .def("get_metadata_str", &lazyllm::DocNode::get_metadata_string, - py::arg("mode") = lazyllm::MetadataMode::ALL) - .def("get_text", &lazyllm::DocNode::get_text, py::arg("metadata_mode") = lazyllm::MetadataMode::NONE) + .def("get_metadata_str", [](const lazyllm::DocNode& node, const py::object& mode) { + if (mode.is_none()) return node.get_metadata_string(lazyllm::MetadataMode::ALL); + return node.get_metadata_string(pyu::ParseMetadataMode(mode)); + }, py::arg("mode") = py::none()) + .def("get_text", [](const lazyllm::DocNode& node, const py::object& metadata_mode) { + return node.get_text(pyu::ParseMetadataMode(metadata_mode)); + }, py::arg("metadata_mode") = py::none()) .def("to_dict", [](const lazyllm::DocNode& node) { py::dict d; d["content"] = node.get_text(lazyllm::MetadataMode::NONE); d["embedding"] = node._embedding_vecs; - d["metadata"] = node._metadata; + d["metadata"] = MetadataToPy(node._metadata); return d; }) .def("with_score", [](const lazyllm::DocNode& node, double score) { diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index b06381cc3..e86b2de97 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -113,13 +113,14 @@ class DocNode { std::vector kv_strings; for (const std::string& key : valid_keys) - kv_strings.emplace_back(key + ":" + std::any_cast(_metadata.at(key))); + kv_strings.emplace_back(key + ": " + any_to_string(_metadata.at(key))); return JoinLines(kv_strings); } std::string get_text(MetadataMode mode = MetadataMode::NONE) const { if (mode == MetadataMode::NONE) return std::string(_text_view); const auto& metadata_string = get_metadata_string(mode); + if (metadata_string.empty()) return std::string(_text_view); return metadata_string + "\n\n" + std::string(_text_view); } void set_root_text(const std::string& text) { diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 
42dfe8b0c..023575555 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace lazyllm { @@ -80,6 +81,8 @@ inline std::string GenerateUUID() { return out; } +std::string any_to_string(const std::any& value); + inline bool is_adjacent(const std::string_view& left, const std::string_view& right) { return left.data() + left.size() == right.data(); } diff --git a/csrc/core/src/utils.cpp b/csrc/core/src/utils.cpp index 0ee624c51..f6dd7736c 100644 --- a/csrc/core/src/utils.cpp +++ b/csrc/core/src/utils.cpp @@ -1 +1,30 @@ #include "utils.hpp" + +namespace lazyllm { + +std::string any_to_string(const std::any& value) { + const auto& t = value.type(); + if (t == typeid(std::string)) return std::any_cast(value); + if (t == typeid(const char*)) return std::string(std::any_cast(value)); + if (t == typeid(char*)) return std::string(std::any_cast(value)); + if (t == typeid(bool)) return std::any_cast(value) ? "True" : "False"; + if (t == typeid(int)) return std::to_string(std::any_cast(value)); + if (t == typeid(long)) return std::to_string(std::any_cast(value)); + if (t == typeid(long long)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned long)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned long long)) return std::to_string(std::any_cast(value)); + if (t == typeid(float)) { + std::ostringstream oss; + oss << std::any_cast(value); + return oss.str(); + } + if (t == typeid(double)) { + std::ostringstream oss; + oss << std::any_cast(value); + return oss.str(); + } + return ""; +} + +} \ No newline at end of file From 2c893dfecaf950073cbb3572457e2324a63afca6 Mon Sep 17 00:00:00 2001 From: yzh Date: Fri, 13 Mar 2026 17:21:22 +0800 Subject: [PATCH 40/40] save --- csrc/adaptor/document_store.hpp | 4 +- csrc/binding/binding_utils.cpp | 69 ++- csrc/binding/binding_utils.hpp | 29 ++ 
csrc/binding/export_doc_node.cpp | 349 +++++++++++-- csrc/binding/export_node_transform.cpp | 355 ++++++++++++- csrc/binding/export_sentence_splitter.cpp | 201 ++++++- csrc/binding/export_text_splitter_base.cpp | 579 ++++++++++++++++----- csrc/core/include/doc_node.hpp | 15 +- csrc/core/include/node_transform.hpp | 1 + csrc/core/include/text_splitter_base.hpp | 3 + csrc/core/include/tokenizer.hpp | 100 +++- csrc/core/include/utils.hpp | 10 +- csrc/core/src/utils.cpp | 63 ++- csrc/tests/test_doc_node.cpp | 10 +- csrc/tests/test_utils.cpp | 18 + 15 files changed, 1572 insertions(+), 234 deletions(-) diff --git a/csrc/adaptor/document_store.hpp b/csrc/adaptor/document_store.hpp index 63bb6a6e4..1be72d066 100644 --- a/csrc/adaptor/document_store.hpp +++ b/csrc/adaptor/document_store.hpp @@ -55,8 +55,8 @@ class LAZYLLM_HIDDEN DocumentStore : public AdaptorBaseWrapper { DocNode::Children get_node_children(const DocNode* node) const { DocNode::Children out; - auto& kb_id = std::any_cast(node->_p_global_metadata->at(std::string(RAGMetadataKeys::KB_ID))); - auto& doc_id = std::any_cast(node->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_ID))); + auto& kb_id = std::get(node->_p_global_metadata->at(std::string(RAGMetadataKeys::KB_ID))); + auto& doc_id = std::get(node->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_ID))); auto& group_name = node->_group_name; for(auto& [current_group_name, group] : _node_groups_map) { if (group._parent != group_name) continue; diff --git a/csrc/binding/binding_utils.cpp b/csrc/binding/binding_utils.cpp index 7030be59a..cf5d82d67 100644 --- a/csrc/binding/binding_utils.cpp +++ b/csrc/binding/binding_utils.cpp @@ -1,5 +1,7 @@ #include "binding_utils.hpp" +#include + namespace lazyllm::pybind_utils { std::string DumpJson(const py::object& obj) { @@ -62,17 +64,80 @@ lazyllm::MetadataMode ParseMetadataMode(const py::object& mode) { return lazyllm::MetadataMode::NONE; } +lazyllm::DocNode::MetadataVType PyToMetadataValue(const 
py::handle& value) { + if (value.is_none()) return std::any(py::none()); + if (py::isinstance(value)) return static_cast(value.cast()); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + if (py::isinstance(value)) return value.cast(); + + if (py::isinstance(value) && !py::isinstance(value)) { + py::sequence seq = value.cast(); + if (seq.empty()) return std::any(py::reinterpret_borrow(value)); + + bool all_str = true; + bool all_int = true; + bool all_numeric = true; + + for (py::handle item : seq) { + const bool is_str = py::isinstance(item); + const bool is_int = py::isinstance(item) && !py::isinstance(item); + const bool is_numeric = is_int || py::isinstance(item) || py::isinstance(item); + all_str = all_str && is_str; + all_int = all_int && is_int; + all_numeric = all_numeric && is_numeric; + } + + if (all_str) { + std::vector out; + out.reserve(seq.size()); + for (py::handle item : seq) out.push_back(py::cast(item)); + return out; + } + if (all_int) { + std::vector out; + out.reserve(seq.size()); + for (py::handle item : seq) out.push_back(py::cast(item)); + return out; + } + if (all_numeric) { + std::vector out; + out.reserve(seq.size()); + for (py::handle item : seq) out.push_back(py::cast(item)); + return out; + } + } + return std::any(py::reinterpret_borrow(value)); +} + +py::object MetadataValueToPy(const lazyllm::DocNode::MetadataVType& value) { + return std::visit([](const auto& v) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v) return py::str(v); + if constexpr (std::is_same_v) return py::int_(v); + if constexpr (std::is_same_v) return py::float_(v); + if constexpr (std::is_same_v>) return py::cast(v); + if constexpr (std::is_same_v>) return py::cast(v); + if constexpr (std::is_same_v>) return py::cast(v); + if constexpr (std::is_same_v) return AnyToPy(v); + return py::none(); + }, value); +} + std::any PyToAny(const py::handle& value) { - if (value.is_none()) return 
std::string("None"); + if (value.is_none()) return py::none(); if (py::isinstance(value)) return value.cast(); if (py::isinstance(value)) return value.cast(); if (py::isinstance(value)) return value.cast(); if (py::isinstance(value)) return value.cast(); - return py::str(value).cast(); + return py::reinterpret_borrow(value); } py::object AnyToPy(const std::any& value) { const auto& t = value.type(); + if (!value.has_value()) return py::none(); + if (t == typeid(py::object)) return std::any_cast(value); + if (t == typeid(py::none)) return py::none(); if (t == typeid(std::string)) return py::str(std::any_cast(value)); if (t == typeid(const char*)) return py::str(std::any_cast(value)); if (t == typeid(char*)) return py::str(std::any_cast(value)); diff --git a/csrc/binding/binding_utils.hpp b/csrc/binding/binding_utils.hpp index 8aee4722e..c1d8252ae 100644 --- a/csrc/binding/binding_utils.hpp +++ b/csrc/binding/binding_utils.hpp @@ -14,7 +14,36 @@ std::string DumpJson(const py::object& obj); py::object LoadJson(const std::string& text); bool ExtractStringSequence(const py::object& obj, std::vector* out); lazyllm::MetadataMode ParseMetadataMode(const py::object& mode); +lazyllm::DocNode::MetadataVType PyToMetadataValue(const py::handle& value); +py::object MetadataValueToPy(const lazyllm::DocNode::MetadataVType& value); std::any PyToAny(const py::handle& value); py::object AnyToPy(const std::any& value); } // namespace lazyllm::pybind_utils + +namespace pybind11::detail { + +template <> +struct type_caster { +public: + PYBIND11_TYPE_CASTER(lazyllm::MetadataVType, _("MetadataVType")); + + bool load(handle src, bool) { + try { + value = lazyllm::pybind_utils::PyToMetadataValue(src); + return true; + } catch (const pybind11::error_already_set&) { + PyErr_Clear(); + return false; + } catch (...) 
{ + return false; + } + } + + static handle cast(const lazyllm::MetadataVType& src, return_value_policy, handle) { + pybind11::object obj = lazyllm::pybind_utils::MetadataValueToPy(src); + return obj.release(); + } +}; + +} // namespace pybind11::detail diff --git a/csrc/binding/export_doc_node.cpp b/csrc/binding/export_doc_node.cpp index 383aed232..f43abb427 100644 --- a/csrc/binding/export_doc_node.cpp +++ b/csrc/binding/export_doc_node.cpp @@ -2,16 +2,45 @@ #include #include #include +#include +#include #include "lazyllm.hpp" #include "document_store.hpp" #include "doc_node.hpp" #include "binding_utils.hpp" +#include + +PYBIND11_MAKE_OPAQUE(lazyllm::DocNode::Metadata); +PYBIND11_MAKE_OPAQUE(lazyllm::DocNode::Children); namespace { namespace pyu = lazyllm::pybind_utils; +const std::unordered_map kDocNodeAttrAliases = { + {"uid", "_uid"}, + {"group", "_group"}, + {"content", "_content"}, + {"parent", "_parent"}, + {"children", "_children"}, + {"global_metadata", "_global_metadata"}, + {"metadata", "_metadata"}, + {"excluded_embed_metadata_keys", "_excluded_embed_metadata_keys"}, + {"excluded_llm_metadata_keys", "_excluded_llm_metadata_keys"}, +}; + +const std::unordered_set kDocNodeReadonlyAliases = { + "uid", + "group", +}; + +const char* ResolveDocNodeAlias(const std::string& attr_name) { + const auto it = kDocNodeAttrAliases.find(attr_name); + if (it == kDocNodeAttrAliases.end()) return nullptr; + return it->second; +} + bool IsJsonDocNode(const py::object& self) { try { const auto name = py::cast(py::type::of(self).attr("__name__")); @@ -28,7 +57,7 @@ lazyllm::DocNode::Metadata MetadataFromPy(const py::object& obj) { out.reserve(d.size()); for (auto item : d) { const std::string key = py::cast(item.first); - out.emplace(key, pyu::PyToAny(item.second)); + out.emplace(key, pyu::PyToMetadataValue(item.second)); } return out; } @@ -36,11 +65,66 @@ lazyllm::DocNode::Metadata MetadataFromPy(const py::object& obj) { py::dict MetadataToPy(const 
lazyllm::DocNode::Metadata& meta) { py::dict d; for (const auto& [key, value] : meta) { - d[py::str(key)] = pyu::AnyToPy(value); + d[py::str(key)] = pyu::MetadataValueToPy(value); + } + return d; +} + +lazyllm::DocNode::Metadata& EnsureRootGlobalMetadata(lazyllm::DocNode& node) { + auto* root = const_cast(node.get_root_node()); + if (!root->_p_global_metadata) { + root->_p_global_metadata = std::make_shared(); + } + return *(root->_p_global_metadata); +} + +const lazyllm::DocNode::Metadata& GetRootGlobalMetadata(const lazyllm::DocNode& node) { + static const lazyllm::DocNode::Metadata empty; + auto* root = node.get_root_node(); + if (!root || !root->_p_global_metadata) return empty; + return *(root->_p_global_metadata); +} + +lazyllm::DocNode::Children ChildrenFromPy(const py::object& obj) { + lazyllm::DocNode::Children out; + if (obj.is_none()) return out; + py::dict d = py::dict(obj); + out.reserve(d.size()); + for (auto item : d) { + const std::string group = py::cast(item.first); + py::object nodes_obj = py::reinterpret_borrow(item.second); + std::vector nodes; + for (auto child_obj : nodes_obj) { + py::object child = py::reinterpret_borrow(child_obj); + if (child.is_none()) continue; + nodes.emplace_back(child.cast()); + } + out.emplace(group, std::move(nodes)); + } + return out; +} + +py::dict ChildrenToPy(const lazyllm::DocNode::Children& children) { + py::dict d; + for (const auto& [group, nodes] : children) { + py::list py_nodes; + for (const auto& n : nodes) py_nodes.append(n); + d[py::str(group)] = std::move(py_nodes); } return d; } +std::vector NodeVectorFromPy(const py::object& obj) { + std::vector out; + if (obj.is_none()) return out; + for (auto item : obj) { + py::object node = py::reinterpret_borrow(item); + if (node.is_none()) continue; + out.emplace_back(node.cast()); + } + return out; +} + using NodeGroups = std::unordered_map>; std::optional NodeGroupsFromPy(const py::object& obj) { @@ -110,9 +194,9 @@ lazyllm::DocNode init( } store_adaptor = 
lazyllm::DocumentStore::from_store(store, node_groups_map); - auto kb_id = std::any_cast(global_metadata_map.at( + auto kb_id = std::get(global_metadata_map.at( std::string(lazyllm::RAGMetadataKeys::KB_ID))); - auto doc_id = std::any_cast(global_metadata_map.at( + auto doc_id = std::get(global_metadata_map.at( std::string(lazyllm::RAGMetadataKeys::DOC_ID))); if (py::isinstance(parent)) { @@ -182,6 +266,162 @@ void exportDocNode(py::module& m) { .value("LLM", lazyllm::MetadataMode::LLM) .value("NONE", lazyllm::MetadataMode::NONE); + auto metadata_map = py::bind_map(m, "DocNodeMetadataMap"); + metadata_map + .def("get", + [](lazyllm::DocNode::Metadata& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it == self.end()) return default_value; + return pyu::MetadataValueToPy(it->second); + }, + py::arg("key"), + py::arg("default") = py::none()) + .def("pop", + [](lazyllm::DocNode::Metadata& self, const std::string& key) { + auto it = self.find(key); + if (it == self.end()) throw py::key_error(py::str(key)); + py::object value = pyu::MetadataValueToPy(it->second); + self.erase(it); + return value; + }, + py::arg("key")) + .def("pop", + [](lazyllm::DocNode::Metadata& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it == self.end()) return default_value; + py::object value = pyu::MetadataValueToPy(it->second); + self.erase(it); + return value; + }, + py::arg("key"), + py::arg("default")) + .def("setdefault", + [](lazyllm::DocNode::Metadata& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it != self.end()) return pyu::MetadataValueToPy(it->second); + auto [inserted, ok] = self.emplace(key, pyu::PyToMetadataValue(default_value)); + (void)ok; + return pyu::MetadataValueToPy(inserted->second); + }, + py::arg("key"), + py::arg("default") = py::none()) + .def("copy", + [](const lazyllm::DocNode::Metadata& self) { + return MetadataToPy(self); + }) + 
.def("__copy__", + [](const lazyllm::DocNode::Metadata& self) { + return MetadataToPy(self); + }) + .def("__deepcopy__", + [](const lazyllm::DocNode::Metadata& self, const py::dict& memo) { + py::object copy = py::module_::import("copy"); + return copy.attr("deepcopy")(MetadataToPy(self), memo); + }, + py::arg("memo")) + .def("update", + [](lazyllm::DocNode::Metadata& self, py::object other, py::kwargs kwargs) { + if (!other.is_none()) { + py::dict d = py::dict(other); + for (auto item : d) { + const std::string key = py::cast(item.first); + self[key] = pyu::PyToMetadataValue(item.second); + } + } + for (auto item : kwargs) { + const std::string key = py::cast(item.first); + self[key] = pyu::PyToMetadataValue(item.second); + } + }, + py::arg("other") = py::none()) + .def("__eq__", + [](const lazyllm::DocNode::Metadata& self, py::object other) { + py::dict lhs = MetadataToPy(self); + if (py::isinstance(other) || py::hasattr(other, "items")) + return py::bool_(lhs.equal(py::dict(other))); + return py::bool_(false); + }, + py::is_operator()); + + auto children_map = py::bind_map(m, "DocNodeChildrenMap"); + children_map + .def("get", + [](lazyllm::DocNode::Children& self, const std::string& key, py::object default_value) -> py::object { + auto it = self.find(key); + if (it == self.end()) return default_value; + return py::cast(it->second); + }, + py::arg("key"), + py::arg("default") = py::none()) + .def("pop", + [](lazyllm::DocNode::Children& self, const std::string& key) { + auto it = self.find(key); + if (it == self.end()) throw py::key_error(py::str(key)); + py::list value = py::cast(it->second); + self.erase(it); + return value; + }, + py::arg("key")) + .def("pop", + [](lazyllm::DocNode::Children& self, const std::string& key, py::object default_value) -> py::object { + auto it = self.find(key); + if (it == self.end()) return default_value; + py::list value = py::cast(it->second); + self.erase(it); + return value; + }, + py::arg("key"), + py::arg("default")) + 
.def("setdefault", + [](lazyllm::DocNode::Children& self, const std::string& key, py::object default_value) { + auto it = self.find(key); + if (it != self.end()) return py::cast(it->second); + auto [inserted, ok] = self.emplace(key, NodeVectorFromPy(default_value)); + (void)ok; + return py::cast(inserted->second); + }, + py::arg("key"), + py::arg("default") = py::list()) + .def("copy", + [](const lazyllm::DocNode::Children& self) { + return ChildrenToPy(self); + }) + .def("__copy__", + [](const lazyllm::DocNode::Children& self) { + return ChildrenToPy(self); + }) + .def("__deepcopy__", + [](const lazyllm::DocNode::Children& self, const py::dict& memo) { + py::object copy = py::module_::import("copy"); + return copy.attr("deepcopy")(ChildrenToPy(self), memo); + }, + py::arg("memo")) + .def("update", + [](lazyllm::DocNode::Children& self, py::object other, py::kwargs kwargs) { + if (!other.is_none()) { + py::dict d = py::dict(other); + for (auto item : d) { + const std::string key = py::cast(item.first); + py::object value = py::reinterpret_borrow(item.second); + self[key] = NodeVectorFromPy(value); + } + } + for (auto item : kwargs) { + const std::string key = py::cast(item.first); + py::object value = py::reinterpret_borrow(item.second); + self[key] = NodeVectorFromPy(value); + } + }, + py::arg("other") = py::none()) + .def("__eq__", + [](const lazyllm::DocNode::Children& self, py::object other) { + py::dict lhs = ChildrenToPy(self); + if (py::isinstance(other) || py::hasattr(other, "items")) + return py::bool_(lhs.equal(py::dict(other))); + return py::bool_(false); + }, + py::is_operator()); + py::class_>(m, "DocNode", py::dynamic_attr()) .def(py::init(&init), py::arg("uid") = py::none(), @@ -195,21 +435,13 @@ void exportDocNode(py::module& m) { py::arg("global_metadata") = py::none(), py::arg("text") = py::none() ) - .def_property_readonly("uid", &lazyllm::DocNode::get_uid) - .def_property_readonly("group", [](const lazyllm::DocNode& node) { return 
node._group_name; }) - .def_property("content", - [](const lazyllm::DocNode& node) { - return std::string(node.get_text(lazyllm::MetadataMode::NONE)); - }, - [](lazyllm::DocNode& node, const py::object& content) { - const auto normalized = NormalizeContent(content); - if (!normalized) return; - if (const auto* content_str = std::get_if(&*normalized)) { - node.set_root_text(std::move(*content_str)); - return; - } - node.set_root_texts(std::get>(*normalized)); - } + .def_property("_uid", + [](const lazyllm::DocNode& node) { return node.get_uid(); }, + [](lazyllm::DocNode& node, const std::string& value) { node.set_uid(value); } + ) + .def_property("_group", + [](const lazyllm::DocNode& node) { return node._group_name; }, + [](lazyllm::DocNode& node, const std::string& value) { node._group_name = value; } ) .def_property("_content", py::cpp_function([](const py::object& self) { @@ -235,9 +467,11 @@ void exportDocNode(py::module& m) { ) .def_property("number", [](const lazyllm::DocNode& node) { - const auto it = node._metadata.find("lazyllm_store_num"); + auto it = node._metadata.find("lazyllm_store_num"); if (it == node._metadata.end()) return 0; - return std::any_cast(it->second); + if (auto* value = std::get_if(&it->second)) return *value; + if (auto* value = std::get_if(&it->second)) return static_cast(*value); + return 0; }, [](lazyllm::DocNode& node, int value) { node._metadata[std::string("lazyllm_store_num")] = value; @@ -253,14 +487,10 @@ void exportDocNode(py::module& m) { node._embedding_vecs = v; } ) - .def_property("parent", + .def_property("_parent", [](const lazyllm::DocNode& node) { return node.get_parent_node(); }, [](lazyllm::DocNode& node, const py::object& parent) { - if (parent.is_none()) { - node.set_parent_node(nullptr); - return; - } - if (py::isinstance(parent)) { + if (parent.is_none() || py::isinstance(parent)) { node.set_parent_node(nullptr); return; } @@ -268,10 +498,13 @@ void exportDocNode(py::module& m) { }, 
py::return_value_policy::reference ) - .def_property("children", - [](const lazyllm::DocNode& node) { return node.get_children(); }, - [](lazyllm::DocNode& node, const lazyllm::DocNode::Children& children) { - node.set_children(children); + .def_property("_children", + py::cpp_function([](lazyllm::DocNode& node) -> lazyllm::DocNode::Children& { + node.get_children(); + return node.get_children_ref(); + }, py::return_value_policy::reference_internal), + [](lazyllm::DocNode& node, const py::object& children) { + node.set_children(ChildrenFromPy(children)); } ) .def_property_readonly("root_node", @@ -281,18 +514,23 @@ void exportDocNode(py::module& m) { .def_property_readonly("is_root_node", [](const lazyllm::DocNode& node) { return node.get_parent_node() == nullptr; } ) - .def_property("global_metadata", - [](const lazyllm::DocNode& node) { return MetadataToPy(*(node.get_root_node()->_p_global_metadata)); }, + .def_property("_global_metadata", + py::cpp_function([](lazyllm::DocNode& node) -> lazyllm::DocNode::Metadata& { + return EnsureRootGlobalMetadata(node); + }, py::return_value_policy::reference_internal), [](lazyllm::DocNode& node, const py::object& meta) { - node._p_global_metadata = std::make_shared( - MetadataFromPy(meta)); + EnsureRootGlobalMetadata(node) = MetadataFromPy(meta); } ) - .def_property("metadata", - [](const lazyllm::DocNode& node) { return MetadataToPy(node._metadata); }, - [](lazyllm::DocNode& node, const py::object& meta) { node._metadata = MetadataFromPy(meta); } + .def_property("_metadata", + py::cpp_function([](lazyllm::DocNode& node) -> lazyllm::DocNode::Metadata& { + return node._metadata; + }, py::return_value_policy::reference_internal), + [](lazyllm::DocNode& node, const py::object& meta) { + node._metadata = MetadataFromPy(meta); + } ) - .def_property("excluded_embed_metadata_keys", + .def_property("_excluded_embed_metadata_keys", [](const lazyllm::DocNode& node) { const auto keys = node.get_excluded_embed_metadata_keys(); return 
std::vector(keys.begin(), keys.end()); @@ -305,7 +543,7 @@ void exportDocNode(py::module& m) { node.set_excluded_embed_metadata_keys(keys); } ) - .def_property("excluded_llm_metadata_keys", + .def_property("_excluded_llm_metadata_keys", [](const lazyllm::DocNode& node) { const auto keys = node.get_excluded_llm_metadata_keys(); return std::vector(keys.begin(), keys.end()); @@ -318,6 +556,21 @@ void exportDocNode(py::module& m) { node.set_excluded_llm_metadata_keys(keys); } ) + .def("__getattr__", [](py::object self, const std::string& attr_name) -> py::object { + if (const char* alias = ResolveDocNodeAlias(attr_name)) { + return py::getattr(self, py::str(alias)); + } + throw py::attribute_error("DocNode has no attribute '" + attr_name + "'"); + }) + .def("__setattr__", [](py::object self, const std::string& attr_name, const py::object& value) { + if (kDocNodeReadonlyAliases.find(attr_name) != kDocNodeReadonlyAliases.end()) { + throw py::attribute_error("property '" + attr_name + "' of 'DocNode' object has no setter"); + } + const char* alias = ResolveDocNodeAlias(attr_name); + const char* target = alias ? 
alias : attr_name.c_str(); + py::object object_setattr = py::module_::import("builtins").attr("object").attr("__setattr__"); + object_setattr(self, py::str(target), value); + }) .def_property("docpath", [](const lazyllm::DocNode& node) { return node.get_doc_path(); }, [](lazyllm::DocNode& node, const std::string& path) { node.set_doc_path(path); } @@ -367,7 +620,7 @@ void exportDocNode(py::module& m) { st["_group"] = node._group_name; st["_embedding"] = node._embedding_vecs; st["_metadata"] = MetadataToPy(node._metadata); - st["_global_metadata"] = MetadataToPy(*(node._p_global_metadata)); + st["_global_metadata"] = MetadataToPy(GetRootGlobalMetadata(node)); st["_excluded_embed_metadata_keys"] = node.get_excluded_embed_metadata_keys(); st["_excluded_llm_metadata_keys"] = node.get_excluded_llm_metadata_keys(); st["_store"] = py::none(); @@ -409,15 +662,19 @@ void exportDocNode(py::module& m) { std::this_thread::sleep_for(std::chrono::seconds(1)); } }) - .def("get_content", [](const lazyllm::DocNode& node) { return node.get_text(lazyllm::MetadataMode::LLM); }) - .def("get_metadata_str", [](const lazyllm::DocNode& node, const py::object& mode) { + .def("get_content", [](py::object self) { + return self.cast().get_text(lazyllm::MetadataMode::LLM); + }) + .def("get_metadata_str", [](py::object self, const py::object& mode) { + auto& node = self.cast(); if (mode.is_none()) return node.get_metadata_string(lazyllm::MetadataMode::ALL); return node.get_metadata_string(pyu::ParseMetadataMode(mode)); }, py::arg("mode") = py::none()) - .def("get_text", [](const lazyllm::DocNode& node, const py::object& metadata_mode) { - return node.get_text(pyu::ParseMetadataMode(metadata_mode)); + .def("get_text", [](py::object self, const py::object& metadata_mode) { + return self.cast().get_text(pyu::ParseMetadataMode(metadata_mode)); }, py::arg("metadata_mode") = py::none()) - .def("to_dict", [](const lazyllm::DocNode& node) { + .def("to_dict", [](py::object self) { + auto& node = 
self.cast(); py::dict d; d["content"] = node.get_text(lazyllm::MetadataMode::NONE); d["embedding"] = node._embedding_vecs; diff --git a/csrc/binding/export_node_transform.cpp b/csrc/binding/export_node_transform.cpp index 99ed61eb1..e162c6d27 100644 --- a/csrc/binding/export_node_transform.cpp +++ b/csrc/binding/export_node_transform.cpp @@ -3,8 +3,15 @@ #include "doc_node.hpp" #include "node_transform.hpp" +#include + +#include +#include + +#include #include #include +#include namespace { @@ -15,6 +22,9 @@ class PyNodeTransform : public lazyllm::NodeTransform { std::vector transform(lazyllm::PDocNode node) const override { py::gil_scoped_acquire gil; py::function overload = py::get_override(static_cast(this), "transform"); + if (!overload) { + overload = py::get_override(static_cast(this), "forward"); + } if (!overload) throw std::runtime_error("NodeTransform.transform is not implemented."); py::object result = overload(node); @@ -32,23 +42,204 @@ class PyNodeTransform : public lazyllm::NodeTransform { } }; +py::object GetBaseModule() { + return py::module_::import("lazyllm.tools.rag.transform.base"); +} + +py::object GetRuleSetClass() { + return GetBaseModule().attr("RuleSet"); +} + +py::object GetContextClass() { + return GetBaseModule().attr("_Context"); +} + +py::object GetDocNodeClass() { + return py::module_::import("lazyllm.tools.rag.doc_node").attr("DocNode"); +} + +py::object GetRichDocNodeClass() { + return py::module_::import("lazyllm.tools.rag.doc_node").attr("RichDocNode"); +} + +bool IsDocNode(const py::object& obj) { + return py::isinstance(obj, GetDocNodeClass()); +} + +bool IsRichDocNode(const py::object& obj) { + return py::isinstance(obj, GetRichDocNodeClass()); +} + +struct TransformRuntimeState { + py::object rules = py::none(); + py::object on_match = py::none(); + py::object on_miss = py::none(); +}; + +std::unordered_map& GetTransformStates() { + // Intentionally leaked to avoid py::object teardown during Python finalization. 
+ static auto* states = new std::unordered_map(); + return *states; +} + +TransformRuntimeState& GetTransformState(const py::object& self) { + auto* ptr = self.cast(); + auto& states = GetTransformStates(); + auto it = states.find(ptr); + if (it == states.end()) { + TransformRuntimeState state; + state.rules = GetRuleSetClass()(); + auto inserted = states.emplace(ptr, std::move(state)); + return inserted.first->second; + } + return it->second; +} + +py::dict GetNodeChildrenDict(const py::object& node) { + try { + return py::dict(node.attr("children")); + } catch (const py::error_already_set&) { + try { + return py::dict(node.attr("_children")); + } catch (const py::error_already_set&) { + return py::dict(); + } + } +} + +void SetNodeChildrenDict(const py::object& node, const py::dict& children) { + try { + node.attr("children") = children; + } catch (const py::error_already_set&) { + try { + node.attr("_children") = children; + } catch (const py::error_already_set&) { + } + } +} + +py::list GetRefNodes(const py::object& node, const py::list& ref_path) { + py::list current; + current.append(node); + for (auto key_obj : ref_path) { + const std::string key = py::cast(key_obj); + py::list next; + for (auto n_obj : current) { + py::object n = py::reinterpret_borrow(n_obj); + py::dict children = GetNodeChildrenDict(n); + py::object group_nodes = children.attr("get")(key, py::list()); + for (auto child : group_nodes) next.append(child); + } + current = next; + } + return current; +} + +py::object CallForward(const py::object& self, const py::object& node, const py::dict& kwargs) { + py::object forward = py::getattr(self, "forward", py::none()); + if (forward.is_none()) { + throw std::runtime_error("NodeTransform.forward is not implemented."); + } + if (kwargs.is_none() || kwargs.empty()) { + return forward(node); + } + return forward(node, **kwargs); +} + +void ExtendList(py::list& target, const py::list& src) { + for (auto item : src) target.append(item); +} + } // 
namespace void exportNodeTransform(py::module& m) { - py::class_(m, "NodeTransform") - .def(py::init(), py::arg("num_workers") = 0) + py::class_(m, "NodeTransform", py::dynamic_attr()) + .def(py::init([](int num_workers, + py::object rules, + bool /*return_trace*/, + py::kwargs /*kwargs*/) { + auto inst = std::make_unique(num_workers); + TransformRuntimeState state; + state.rules = rules.is_none() ? GetRuleSetClass()() : rules; + GetTransformStates()[inst.get()] = std::move(state); + return inst; + }), + py::arg("num_workers") = 0, + py::arg("rules") = py::none(), + py::arg("return_trace") = false + ) .def("batch_forward", - [](lazyllm::NodeTransform& self, + [](py::object self, py::object documents, const std::string& node_group, - py::object /*ref_path*/, - py::kwargs /*kwargs*/) { - std::vector docs; - if (py::isinstance(documents)) { - for (auto item : documents) docs.push_back(py::cast(item)); - } else - docs.push_back(documents.cast()); - return self.batch_forward(docs, node_group); + py::object ref_path, + py::kwargs kwargs) { + py::list docs; + if (py::isinstance(documents) || py::isinstance(documents)) { + for (auto item : documents) docs.append(py::reinterpret_borrow(item)); + } else { + docs.append(documents); + } + + py::list all_outputs; + const bool support_rich = py::bool_(py::getattr(self, "__support_rich__", py::bool_(false))); + + for (auto node_obj : docs) { + py::object node = py::reinterpret_borrow(node_obj); + py::dict children = GetNodeChildrenDict(node); + if (children.contains(py::str(node_group))) { + continue; + } + + py::list splits; + if (!ref_path.is_none()) { + py::list ref_nodes = GetRefNodes(node, py::cast(ref_path)); + if (ref_nodes.empty()) continue; + + py::dict forward_kwargs = py::dict(kwargs); + forward_kwargs["ref"] = ref_nodes; + if (support_rich) { + if (py::len(ref_nodes) == 1) { + splits = py::cast(CallForward(self, ref_nodes[0], forward_kwargs)); + } else { + py::object rich = GetRichDocNodeClass()(py::arg("nodes") = 
ref_nodes); + splits = py::cast(CallForward(self, rich, forward_kwargs)); + } + } else { + splits = py::list(); + for (auto ref_node_obj : ref_nodes) { + py::object ref_node = py::reinterpret_borrow(ref_node_obj); + py::list out = py::cast(CallForward(self, ref_node, forward_kwargs)); + ExtendList(splits, out); + } + } + } else { + if (IsRichDocNode(node) && !support_rich) { + splits = py::list(); + for (auto sub : node.attr("nodes")) { + py::object sub_node = py::reinterpret_borrow(sub); + py::list out = py::cast(CallForward(self, sub_node, kwargs)); + ExtendList(splits, out); + } + } else { + splits = py::cast(CallForward(self, node, kwargs)); + } + } + + for (auto s_obj : splits) { + py::object s = py::reinterpret_borrow(s_obj); + try { + s.attr("parent") = node; + } catch (const py::error_already_set&) { + } + py::setattr(s, "_group", py::str(node_group)); + } + children[py::str(node_group)] = splits; + SetNodeChildrenDict(node, children); + ExtendList(all_outputs, splits); + } + + return all_outputs; }, py::arg("documents"), py::arg("node_group"), @@ -56,25 +247,125 @@ void exportNodeTransform(py::module& m) { py::return_value_policy::reference ) .def("transform", - [](const lazyllm::NodeTransform& self, lazyllm::PDocNode node, py::kwargs /*kwargs*/) { - if (node == nullptr) return std::vector{}; - return self.transform(node); + [](py::object self, py::object node, py::kwargs kwargs) -> py::object { + if (node.is_none()) return py::object(py::list()); + return CallForward(self, node, py::dict(kwargs)); }, py::arg("document") ) - .def("__call__", - [](const lazyllm::NodeTransform& self, lazyllm::PDocNode node, py::kwargs /*kwargs*/) { - if (node == nullptr) return std::vector{}; - return self(node); + .def("forward", + [](py::object /*self*/, py::object /*node*/, py::kwargs /*kwargs*/) { + PyErr_SetString(PyExc_NotImplementedError, + "Subclasses must implement forward() to process a single DocNode or RichDocNode"); + throw py::error_already_set(); }, 
py::arg("node") ) + .def("__call__", + [](py::object self, py::object node_or_nodes, py::kwargs kwargs) -> py::object { + py::list results; + py::object forward_single = self.attr("_forward_single"); + if (py::isinstance(node_or_nodes) || py::isinstance(node_or_nodes)) { + for (auto item : node_or_nodes) { + py::object node = py::reinterpret_borrow(item); + if (!IsDocNode(node)) { + throw py::type_error( + "__call__() expects DocNode objects, got non-DocNode in list."); + } + py::list out = py::cast(forward_single(node, **kwargs)); + ExtendList(results, out); + } + return py::object(results); + } + + if (!IsDocNode(node_or_nodes)) { + throw py::type_error("__call__() expects DocNode or RichDocNode."); + } + return forward_single(node_or_nodes, **kwargs); + } + ) + .def("_forward_single", + [](py::object self, py::object node, py::kwargs kwargs) -> py::object { + const bool support_rich = py::bool_(py::getattr(self, "__support_rich__", py::bool_(false))); + if (IsRichDocNode(node) && !support_rich) { + py::list out; + for (auto sub : node.attr("nodes")) { + py::object sub_node = py::reinterpret_borrow(sub); + py::list res = py::cast(CallForward(self, sub_node, py::dict(kwargs))); + ExtendList(out, res); + } + return py::object(out); + } + return CallForward(self, node, py::dict(kwargs)); + } + ) + .def("process", + [](py::object self, py::list nodes, py::object on_match, py::object on_miss) { + py::object rules = py::getattr(self, "_rules", py::none()); + if (rules.is_none()) rules = GetRuleSetClass()(); + + py::object match_handler = on_match; + py::object miss_handler = on_miss; + + py::object instance_match = py::getattr(self, "_on_match", py::none()); + py::object instance_miss = py::getattr(self, "_on_miss", py::none()); + if (match_handler.is_none()) { + match_handler = instance_match.is_none() + ? py::getattr(self, "_default_match_handler") + : instance_match; + } + if (miss_handler.is_none()) { + miss_handler = instance_miss.is_none() + ? 
py::getattr(self, "_default_miss_handler") + : instance_miss; + } + + py::object ctx = GetContextClass()(py::arg("total") = py::len(nodes)); + py::list results; + size_t i = 0; + for (auto node_obj : nodes) { + py::object node = py::reinterpret_borrow(node_obj); + ctx.attr("current_idx") = i++; + py::object match = rules.attr("first")(node); + py::object processed = match.is_none() + ? miss_handler(node, ctx) + : match_handler(node, match, ctx); + results.append(processed); + ctx.attr("prev_node") = node; + ctx.attr("prev_result") = processed; + } + return results; + }, + py::arg("nodes"), + py::arg("on_match") = py::none(), + py::arg("on_miss") = py::none() + ) + .def("_default_match_handler", + [](py::object /*self*/, py::object /*node*/, py::object matched, py::object /*ctx*/) { + py::tuple tup = matched.cast(); + return tup[1]; + } + ) + .def("_default_miss_handler", + [](py::object /*self*/, py::object node, py::object /*ctx*/) { + return node; + } + ) .def( "with_name", - [](lazyllm::NodeTransform& self, py::object name, bool copy) -> lazyllm::NodeTransform& { - (void)copy; + [](py::object self, py::object name, bool copy) -> py::object { if (name.is_none()) return self; - self._name = name.cast(); + if (copy) { + try { + py::object copier = py::module_::import("copy").attr("copy"); + py::object new_self = copier(self); + new_self.attr("_name") = name; + return new_self; + } catch (const py::error_already_set&) { + // Fallback to in-place mutation if copy fails. 
+ } + } + self.attr("_name") = name; return self; }, py::arg("name"), @@ -83,5 +374,25 @@ void exportNodeTransform(py::module& m) { py::return_value_policy::reference ) .def_readwrite("_name", &lazyllm::NodeTransform::_name) - .def_property_readonly("_number_workers", &lazyllm::NodeTransform::worker_num); + .def_property("_number_workers", + &lazyllm::NodeTransform::worker_num, + &lazyllm::NodeTransform::set_worker_num + ) + .def_property("_rules", + [](py::object self) { return GetTransformState(self).rules; }, + [](py::object self, py::object value) { + GetTransformState(self).rules = value.is_none() ? GetRuleSetClass()() : value; + } + ) + .def_property("_on_match", + [](py::object self) { return GetTransformState(self).on_match; }, + [](py::object self, py::object value) { GetTransformState(self).on_match = value; } + ) + .def_property("_on_miss", + [](py::object self) { return GetTransformState(self).on_miss; }, + [](py::object self, py::object value) { GetTransformState(self).on_miss = value; } + ) + ; + + m.attr("NodeTransform").attr("__support_rich__") = py::bool_(false); } diff --git a/csrc/binding/export_sentence_splitter.cpp b/csrc/binding/export_sentence_splitter.cpp index 3f32d2165..bc3d4efac 100644 --- a/csrc/binding/export_sentence_splitter.cpp +++ b/csrc/binding/export_sentence_splitter.cpp @@ -2,12 +2,17 @@ #include "sentence_splitter.hpp" -#include +#include +#include +#include #include +#include #include +#include #include #include +#include namespace { @@ -25,6 +30,84 @@ class PySentenceSplitter final : public lazyllm::SentenceSplitter { } }; +py::object GetBaseModule() { + return py::module_::import("lazyllm.tools.rag.transform.base"); +} + +py::object GetUnset() { + return GetBaseModule().attr("_UNSET"); +} + +bool IsUnset(const py::object& value) { + return value.is(GetUnset()); +} + +py::dict GetDefaultParams(const py::object& cls) { + if (!py::hasattr(cls, "_default_params")) { + cls.attr("_default_params") = py::dict(); + } + return 
cls.attr("_default_params").cast(); +} + +unsigned ResolveUnsigned( + const py::object& cls, + const std::string& param_name, + const py::object& value, + unsigned default_value +) { + if (value.is_none()) return default_value; + if (!IsUnset(value)) { + if (py::isinstance(value)) return value.cast(); + throw py::type_error(param_name + " must be an int"); + } + + py::dict defaults = GetDefaultParams(cls); + if (defaults.contains(param_name.c_str())) { + py::object v = py::reinterpret_borrow(defaults[param_name.c_str()]); + if (py::isinstance(v)) return v.cast(); + } + return default_value; +} + +std::string Trim(const std::string& input) { + size_t start = 0; + size_t end = input.size(); + while (start < end && std::isspace(static_cast(input[start]))) ++start; + while (end > start && std::isspace(static_cast(input[end - 1]))) --end; + return input.substr(start, end - start); +} + +void InitTiktokenTokenizer( + const py::object& py_self, + const std::string& encoding_name, + const py::object& model_name, + py::object allowed_special, + py::object disallowed_special +) { + (void)allowed_special; + (void)disallowed_special; + + std::string tokenizer_name = encoding_name; + if (!model_name.is_none()) { + if (!py::isinstance(model_name)) { + throw py::type_error("model_name must be a string"); + } + tokenizer_name = model_name.cast(); + } + auto tokenizer = std::make_shared(tokenizer_name); + + py::object encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer->encode(text); + }); + py::object decode = py::cpp_function([tokenizer](const std::vector& ids) { + py::bytes raw(tokenizer->decode(ids)); + return raw.attr("decode")("utf-8", "replace"); + }); + + py_self.attr("token_encoder") = encode; + py_self.attr("token_decoder") = decode; +} + } // namespace void exportSentenceSplitter(py::module& m) { @@ -32,22 +115,110 @@ void exportSentenceSplitter(py::module& m) { lazyllm::SentenceSplitter, lazyllm::TextSplitterBase, PySentenceSplitter - 
>(m, "SentenceSplitter") - .def(py::init([]( - std::optional chunk_size, - std::optional chunk_overlap, - std::optional num_workers, - py::object encoding_name - ) { - return std::make_unique( - chunk_size, - chunk_overlap, - num_workers, - encoding_name.is_none() ? "gpt2" : encoding_name.cast()); + >(m, "SentenceSplitter", py::dynamic_attr()) + .def(py::init([](py::object chunk_size, + py::object chunk_overlap, + py::object num_workers) { + py::object cls = py::module_::import("lazyllm.lazyllm_cpp").attr("SentenceSplitter"); + + unsigned cs = ResolveUnsigned(cls, "chunk_size", chunk_size, 1024); + unsigned ov = ResolveUnsigned(cls, "overlap", chunk_overlap, 200); + unsigned nw = ResolveUnsigned(cls, "num_workers", num_workers, 0); + + if (ov > cs) { + throw py::value_error( + "Got a larger chunk overlap (" + std::to_string(ov) + ") than chunk size (" + + std::to_string(cs) + "), should be smaller."); + } + if (cs == 0) { + throw py::value_error("chunk size should > 0 and overlap should >= 0"); + } + + return std::make_unique(cs, ov, nw, "gpt2"); }), py::arg("chunk_size") = py::none(), py::arg("chunk_overlap") = py::none(), - py::arg("num_workers") = py::none(), - py::arg("encoding_name") = py::none() + py::arg("num_workers") = py::none() + ) + .def("_merge", + [](lazyllm::SentenceSplitter& self, py::list splits, int chunk_size) { + if (py::len(splits) == 0) return std::vector{}; + + struct SplitData { std::string text; bool is_sentence; int token_size; }; + std::vector data; + data.reserve(py::len(splits)); + for (auto item : splits) { + py::object s = py::reinterpret_borrow(item); + data.push_back({ + s.attr("text").cast(), + s.attr("is_sentence").cast(), + s.attr("token_size").cast() + }); + } + + std::vector chunks; + std::vector> cur_chunk; + int cur_chunk_len = 0; + bool is_chunk_new = true; + int overlap = self.overlap(); + + auto close_chunk = [&]() { + std::string joined; + joined.reserve(256); + for (const auto& part : cur_chunk) joined += part.first; + 
chunks.push_back(joined); + + auto last_chunk = cur_chunk; + cur_chunk.clear(); + cur_chunk_len = 0; + is_chunk_new = true; + + int overlap_len = 0; + for (auto it = last_chunk.rbegin(); it != last_chunk.rend(); ++it) { + if (overlap_len + it->second > overlap) break; + cur_chunk.push_back(*it); + overlap_len += it->second; + cur_chunk_len += it->second; + } + std::reverse(cur_chunk.begin(), cur_chunk.end()); + }; + + size_t i = 0; + while (i < data.size()) { + const auto& cur_split = data[i]; + if (cur_split.token_size > chunk_size) { + throw py::value_error("Single token exceeded chunk size"); + } + if (cur_chunk_len + cur_split.token_size > chunk_size && !is_chunk_new) { + close_chunk(); + } else { + if (cur_split.is_sentence + || cur_chunk_len + cur_split.token_size <= chunk_size + || is_chunk_new) { + cur_chunk_len += cur_split.token_size; + cur_chunk.push_back({cur_split.text, cur_split.token_size}); + ++i; + is_chunk_new = false; + } else { + close_chunk(); + } + } + } + + if (!is_chunk_new) { + std::string joined; + joined.reserve(256); + for (const auto& part : cur_chunk) joined += part.first; + chunks.push_back(joined); + } + + std::vector out; + for (const auto& chunk : chunks) { + std::string trimmed = Trim(chunk); + if (!trimmed.empty()) out.push_back(trimmed); + } + return out; + }, + py::arg("splits"), py::arg("chunk_size") ); } diff --git a/csrc/binding/export_text_splitter_base.cpp b/csrc/binding/export_text_splitter_base.cpp index 42bbeb3d7..9305f163a 100644 --- a/csrc/binding/export_text_splitter_base.cpp +++ b/csrc/binding/export_text_splitter_base.cpp @@ -2,116 +2,293 @@ #include "text_splitter_base.hpp" -#include +#include #include -#include #include #include -#include +#include #include #include #include #include +#include namespace { using SplitFn = lazyllm::TextSplitterBase::SplitFn; -std::any py_to_any(const py::handle& value) { - if (py::isinstance(value)) return value.cast(); - if (py::isinstance(value)) return value.cast(); - if 
(py::isinstance(value)) return value.cast(); - if (py::isinstance(value)) return value.cast(); - throw std::runtime_error("Unsupported default parameter type."); +class PyTextSplitterBase final : public lazyllm::TextSplitterBase { +public: + using lazyllm::TextSplitterBase::TextSplitterBase; + + std::vector transform(const lazyllm::PDocNode node) const override { + PYBIND11_OVERRIDE( + std::vector, + lazyllm::TextSplitterBase, + transform, + node + ); + } +}; + +py::object GetBaseModule() { + return py::module_::import("lazyllm.tools.rag.transform.base"); } -py::object any_to_py(const std::any& value) { - if (value.type() == typeid(bool)) return py::bool_(std::any_cast(value)); - if (value.type() == typeid(int)) return py::int_(std::any_cast(value)); - if (value.type() == typeid(double)) return py::float_(std::any_cast(value)); - if (value.type() == typeid(std::string)) return py::str(std::any_cast(value)); - return py::none(); +py::object GetUnset() { + return GetBaseModule().attr("_UNSET"); } -class PyTokenizer final : public Tokenizer { -public: - enum class Mode { - Generic, - HuggingFace - }; - - explicit PyTokenizer(py::object obj, Mode mode = Mode::Generic) - : _obj(std::move(obj)), _mode(mode) {} - - std::vector encode(const std::string_view& text) const override { - py::gil_scoped_acquire gil; - py::object func = py::getattr(_obj, "encode", py::none()); - if (func.is_none()) throw std::runtime_error("Tokenizer missing method: encode"); - - py::object result; - if (_mode == Mode::HuggingFace) { - result = func(std::string(text), py::arg("add_special_tokens") = false); +bool IsUnset(const py::object& value) { + return value.is(GetUnset()); +} + +py::dict GetDefaultParams(const py::object& cls) { + if (!py::hasattr(cls, "_default_params")) { + cls.attr("_default_params") = py::dict(); + } + return cls.attr("_default_params").cast(); +} + +unsigned ResolveUnsigned( + const py::object& cls, + const std::string& param_name, + const py::object& value, + unsigned 
default_value +) { + if (value.is_none()) return default_value; + if (!IsUnset(value)) { + if (py::isinstance(value)) return value.cast(); + throw py::type_error(param_name + " must be an int"); + } + + py::dict defaults = GetDefaultParams(cls); + if (defaults.contains(param_name.c_str())) { + py::object v = py::reinterpret_borrow(defaults[param_name.c_str()]); + if (py::isinstance(v)) return v.cast(); + } + return default_value; +} + +py::object GetSplitClass() { + return GetBaseModule().attr("_Split"); +} + +py::object MakeSplit(const std::string& text, bool is_sentence, int token_size) { + return GetSplitClass()(py::arg("text") = text, + py::arg("is_sentence") = is_sentence, + py::arg("token_size") = token_size); +} + +std::pair, bool> GetSplitsByFns(const std::string& text) { + py::object base = GetBaseModule(); + py::object split_keep = base.attr("split_text_keep_separator"); + py::object nltk = base.attr("nltk"); + py::object tokenizer = nltk.attr("tokenize").attr("PunktSentenceTokenizer")(); + + py::object splits_obj = split_keep(text, "\n\n\n"); + auto splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, true}; + + splits_obj = tokenizer.attr("tokenize")(text); + splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, true}; + + py::object re = base.attr("re"); + splits_obj = re.attr("findall")(py::str(R"([^,.;。?!]+[,.;。?!]?)"), text); + splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, false}; + + splits_obj = split_keep(text, " "); + splits = splits_obj.cast>(); + if (splits.size() > 1) return {splits, false}; + + splits_obj = py::list(py::str(text)); + splits = splits_obj.cast>(); + return {splits, false}; +} + +void InitTiktokenTokenizer( + const py::object& py_self, + const std::string& encoding_name, + const py::object& model_name, + py::object allowed_special, + py::object disallowed_special +); + +void EnsureTokenCodec(const py::object& self) { + py::dict obj_dict = py::getattr(self, "__dict__", 
py::dict()).cast(); + const bool has_encoder = + obj_dict.contains("token_encoder") && + !py::reinterpret_borrow(obj_dict["token_encoder"]).is_none(); + const bool has_decoder = + obj_dict.contains("token_decoder") && + !py::reinterpret_borrow(obj_dict["token_decoder"]).is_none(); + if (!has_encoder || !has_decoder) { + InitTiktokenTokenizer(self, "gpt2", py::none(), py::none(), py::str("all")); + } +} + +int TokenSize(const py::object& self, const std::string& text) { + EnsureTokenCodec(self); + py::object encoder = self.attr("token_encoder"); + py::object tokens = encoder(text); + return static_cast(py::len(tokens)); +} + +void ExtendList(py::list& target, const py::list& src) { + for (auto item : src) target.append(item); +} + +py::list SplitRecursive(const py::object& self, const std::string& text, int chunk_size) { + const int token_size = TokenSize(self, text); + if (token_size <= chunk_size) { + py::list out; + out.append(MakeSplit(text, true, token_size)); + return out; + } + + auto [text_splits, is_sentence] = GetSplitsByFns(text); + py::list results; + for (const auto& segment : text_splits) { + const int seg_token_size = TokenSize(self, segment); + if (seg_token_size <= chunk_size) { + results.append(MakeSplit(segment, is_sentence, seg_token_size)); } else { - result = func(std::string(text)); + py::list sub = SplitRecursive(self, segment, chunk_size); + ExtendList(results, sub); } - return result.cast>(); } + return results; +} - std::string decode(const std::vector& token_ids) const override { - py::gil_scoped_acquire gil; - py::object func = py::getattr(_obj, "decode", py::none()); - if (func.is_none()) throw std::runtime_error("Tokenizer missing method: decode"); +void InitTiktokenTokenizer( + const py::object& py_self, + const std::string& encoding_name, + const py::object& model_name, + py::object allowed_special, + py::object disallowed_special +) { + (void)allowed_special; + (void)disallowed_special; - py::object result; - if (_mode == 
Mode::HuggingFace) { - result = func(token_ids, py::arg("skip_special_tokens") = true); - } else { - result = func(token_ids); + std::string tokenizer_name = encoding_name; + if (!model_name.is_none()) { + if (!py::isinstance(model_name)) { + throw py::type_error("model_name must be a string"); } - return result.cast(); + tokenizer_name = model_name.cast(); } + auto tokenizer = std::make_shared(tokenizer_name); -private: - py::object _obj; - Mode _mode; -}; + py::object encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer->encode(text); + }); + py::object decode = py::cpp_function([tokenizer](const std::vector& ids) { + py::bytes raw(tokenizer->decode(ids)); + return raw.attr("decode")("utf-8", "replace"); + }); -class PyTextSplitterBase final : public lazyllm::TextSplitterBase { -public: - using lazyllm::TextSplitterBase::TextSplitterBase; + py_self.attr("token_encoder") = encode; + py_self.attr("token_decoder") = decode; +} - std::vector transform(const lazyllm::PDocNode node) const override { - PYBIND11_OVERRIDE( - std::vector, - lazyllm::TextSplitterBase, - transform, - node - ); +void InitTokenizerFromObject(const py::object& py_self, const py::object& tokenizer, bool huggingface) { + py::object encode; + py::object decode; + if (huggingface) { + encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer.attr("encode")(text, py::arg("add_special_tokens") = false); + }); + decode = py::cpp_function([tokenizer](const std::vector& ids) { + return tokenizer.attr("decode")(ids, py::arg("skip_special_tokens") = true); + }); + } else { + encode = py::cpp_function([tokenizer](const std::string& text) { + return tokenizer.attr("encode")(text); + }); + decode = py::cpp_function([tokenizer](const std::vector& ids) { + return tokenizer.attr("decode")(ids); + }); } -}; + + py_self.attr("token_encoder") = encode; + py_self.attr("token_decoder") = decode; +} } // namespace void exportTextSpliterBase(py::module& m) { - 
py::class_(m, "_TextSplitterBase") - .def(py::init([]( - std::optional chunk_size, - std::optional overlap, - std::optional num_workers, - py::object encoding_name) { - return std::make_unique( - chunk_size, overlap, num_workers, encoding_name.is_none() ? "gpt2" : encoding_name.cast()); + py::class_( + m, "_TextSplitterBase", py::dynamic_attr()) + .def(py::init([](py::object chunk_size, + py::object overlap, + py::object num_workers) { + py::object cls = py::module_::import("lazyllm.lazyllm_cpp").attr("_TextSplitterBase"); + + unsigned cs = ResolveUnsigned(cls, "chunk_size", chunk_size, 1024); + unsigned ov = ResolveUnsigned(cls, "overlap", overlap, 200); + unsigned nw = ResolveUnsigned(cls, "num_workers", num_workers, 0); + + if (ov > cs) { + throw py::value_error( + "Got a larger chunk overlap (" + std::to_string(ov) + ") than chunk size (" + + std::to_string(cs) + "), should be smaller."); + } + if (cs == 0) { + throw py::value_error("chunk size should > 0 and overlap should >= 0"); + } + + return std::make_unique(cs, ov, nw, "gpt2"); }), py::arg("chunk_size") = py::none(), py::arg("overlap") = py::none(), - py::arg("num_workers") = py::none(), - py::arg("encoding_name") = py::none() + py::arg("num_workers") = py::none() + ) + .def_property_readonly("_chunk_size", &lazyllm::TextSplitterBase::chunk_size) + .def_property_readonly("_overlap", &lazyllm::TextSplitterBase::overlap) + .def("__getattr__", + [](py::object self, const std::string& name) -> py::object { + if (name == "token_encoder" || name == "token_decoder") { + EnsureTokenCodec(self); + py::dict obj_dict = py::getattr(self, "__dict__", py::dict()).cast(); + if (obj_dict.contains(name.c_str())) { + return py::reinterpret_borrow(obj_dict[name.c_str()]); + } + } + throw py::attribute_error( + "'" + std::string(py::str(py::type::of(self).attr("__name__"))) + + "' object has no attribute '" + name + "'"); + }, + py::arg("name") ) .def("split_text", - [](const lazyllm::TextSplitterBase& self, const 
std::string& text, int metadata_size) { + [](lazyllm::TextSplitterBase& self, const std::string& text, int metadata_size) { if (text.empty()) return std::vector{""}; - return self.split_text(text, metadata_size); + + const int chunk_size = self.chunk_size(); + const int effective_chunk_size = chunk_size - metadata_size; + if (effective_chunk_size <= 0) { + throw py::value_error( + "Metadata length (" + std::to_string(metadata_size) + ") is longer than chunk size (" + + std::to_string(chunk_size) + "). Consider increasing the chunk size or decreasing the size of " + "your metadata to avoid this."); + } else if (effective_chunk_size < 50) { + try { + py::object log = py::module_::import("lazyllm").attr("LOG"); + log.attr("warning")( + "Metadata length (" + std::to_string(metadata_size) + ") is close to chunk size (" + + std::to_string(chunk_size) + "). Resulting chunks are less than 50 tokens. " + "Consider increasing the chunk size or decreasing the size of your metadata to avoid this."); + } catch (const py::error_already_set&) { + } + } + + py::object py_self = py::cast(&self, py::return_value_policy::reference); + py::object splits = py_self.attr("_split")(text, effective_chunk_size); + py::object chunks = py_self.attr("_merge")(splits, effective_chunk_size); + return chunks.cast>(); }, py::arg("text"), py::arg("metadata_size") ) @@ -126,97 +303,235 @@ void exportTextSpliterBase(py::module& m) { py::arg("text"), py::arg("separator") ) .def("from_tiktoken_encoder", - [](lazyllm::TextSplitterBase& self, + [](py::object self, const std::string& encoding_name, py::object model_name, - py::object /*allowed_special*/, - py::object /*disallowed_special*/, - py::kwargs /*kwargs*/) -> lazyllm::TextSplitterBase& { - if (model_name.is_none()) { - return self.from_tiktoken_encoder(encoding_name, std::nullopt); - } - return self.from_tiktoken_encoder(encoding_name, model_name.cast()); + py::object allowed_special, + py::object disallowed_special, + py::kwargs /*kwargs*/) { + 
InitTiktokenTokenizer(self, encoding_name, model_name, allowed_special, disallowed_special); + return self; }, py::arg("encoding_name") = "gpt2", py::arg("model_name") = py::none(), py::arg("allowed_special") = py::none(), - py::arg("disallowed_special") = "all", - py::return_value_policy::reference + py::arg("disallowed_special") = "all" ) .def("from_tiktoken_encoding", - [](lazyllm::TextSplitterBase& self, const std::string& encoding_name) -> lazyllm::TextSplitterBase& { - return self.from_tiktoken_encoder(encoding_name, std::nullopt); + [](py::object self, const std::string& encoding_name) { + InitTiktokenTokenizer(self, encoding_name, py::none(), py::none(), py::str("all")); + return self; }, - py::arg("encoding_name") = "gpt2", - py::return_value_policy::reference + py::arg("encoding_name") = "gpt2" ) .def("from_tokenizer", - [](lazyllm::TextSplitterBase& self, py::object tokenizer) -> lazyllm::TextSplitterBase& { - self.set_tokenizer(std::make_shared(std::move(tokenizer), PyTokenizer::Mode::Generic)); + [](py::object self, py::object tokenizer) { + InitTokenizerFromObject(self, tokenizer, false); return self; }, - py::arg("tokenizer"), - py::return_value_policy::reference + py::arg("tokenizer") ) .def("from_huggingface_tokenizer", - [](lazyllm::TextSplitterBase& self, py::object tokenizer) -> lazyllm::TextSplitterBase& { - self.set_tokenizer(std::make_shared(std::move(tokenizer), PyTokenizer::Mode::HuggingFace)); + [](py::object self, py::object tokenizer) { + InitTokenizerFromObject(self, tokenizer, true); return self; }, - py::arg("tokenizer"), - py::return_value_policy::reference + py::arg("tokenizer") ) .def("set_split_fns", - [](lazyllm::TextSplitterBase& self, const std::vector& split_fns, py::object sub_split_fns) { - if (sub_split_fns.is_none()) { - self.set_split_functions(split_fns, std::nullopt); - return; - } - self.set_split_functions(split_fns, sub_split_fns.cast>()); + [](py::object /*self*/, const py::object& /*split_fns*/, py::object 
/*sub_split_fns*/) { + return; }, py::arg("split_fns"), py::arg("sub_split_fns") = py::none() ) .def("add_split_fn", - [](lazyllm::TextSplitterBase& self, const SplitFn& split_fn, py::object index) { - if (index.is_none()) { - self.add_split_function(split_fn, std::nullopt); - return; - } - self.add_split_function(split_fn, index.cast()); + [](py::object /*self*/, const py::object& /*split_fn*/, py::object /*index*/) { + return; }, py::arg("split_fn"), py::arg("index") = py::none() ) - .def("clear_split_fns", &lazyllm::TextSplitterBase::clear_split_functions) - .def_static("set_default", - [](py::kwargs kwargs) { - lazyllm::MapParams::MapType updates; - for (auto item : kwargs) { - const auto key = py::cast(item.first); - updates[key] = py_to_any(item.second); - } - lazyllm::TextSplitterBase::_default_params.set_default(updates); + .def("clear_split_fns", + [](py::object /*self*/) { + return; } ) - .def_static("get_default", - // TODO: too verbose - [](py::object param_name) -> py::object { - const auto defaults = lazyllm::TextSplitterBase::_default_params.get_default(); - if (param_name.is_none()) { - py::dict out; - for (const auto& [key, value] : defaults) { - out[py::str(key)] = any_to_py(value); + .def("_token_size", + [](py::object self, const std::string& text) { + return TokenSize(self, text); + }, + py::arg("text") + ) + .def("_get_metadata_size", + [](py::object self, py::object node) { + py::object meta_mode = GetBaseModule().attr("MetadataMode"); + std::string embed = node.attr("get_metadata_str")(meta_mode.attr("EMBED")).cast(); + std::string llm = node.attr("get_metadata_str")(meta_mode.attr("LLM")).cast(); + return std::max(TokenSize(self, embed), TokenSize(self, llm)); + }, + py::arg("node") + ) + .def("_get_splits_by_fns", + [](py::object /*self*/, const std::string& text) { + auto [splits, is_sentence] = GetSplitsByFns(text); + return py::make_tuple(splits, is_sentence); + }, + py::arg("text") + ) + .def("_split", + [](py::object self, const 
std::string& text, int chunk_size) { + return SplitRecursive(self, text, chunk_size); + }, + py::arg("text"), py::arg("chunk_size") + ) + .def("_merge", + [](lazyllm::TextSplitterBase& self, py::list splits, int chunk_size) { + if (py::len(splits) == 0) return std::vector{}; + if (py::len(splits) == 1) { + py::object s = splits[0]; + return std::vector{py::cast(s.attr("text"))}; + } + + struct SplitData { std::string text; bool is_sentence; int token_size; }; + std::vector data; + data.reserve(py::len(splits)); + for (auto item : splits) { + py::object s = py::reinterpret_borrow(item); + data.push_back({ + s.attr("text").cast(), + s.attr("is_sentence").cast(), + s.attr("token_size").cast() + }); + } + + const int overlap = self.overlap(); + py::object py_self = py::cast(&self, py::return_value_policy::reference); + EnsureTokenCodec(py_self); + py::object encoder = py_self.attr("token_encoder"); + py::object decoder = py_self.attr("token_decoder"); + + if (data.back().token_size == chunk_size && overlap > 0) { + SplitData end_split = data.back(); + data.pop_back(); + auto tokens = encoder(end_split.text).cast>(); + const size_t half = tokens.size() / 2; + auto mid = static_cast::difference_type>(half); + std::vector prefix(tokens.begin(), tokens.begin() + mid); + std::vector suffix(tokens.begin() + mid, tokens.end()); + std::string p_text = decoder(prefix).cast(); + std::string n_text = decoder(suffix).cast(); + data.push_back({p_text, end_split.is_sentence, TokenSize(py_self, p_text)}); + data.push_back({n_text, end_split.is_sentence, TokenSize(py_self, n_text)}); + } + + SplitData end_split = data.back(); + std::vector result; + for (int idx = static_cast(data.size()) - 2; idx >= 0; --idx) { + SplitData start_split = data[static_cast(idx)]; + if (start_split.token_size <= overlap && end_split.token_size <= chunk_size - overlap) { + end_split.text = start_split.text + end_split.text; + end_split.is_sentence = start_split.is_sentence && end_split.is_sentence; + 
end_split.token_size += start_split.token_size; + continue; } - return py::object(std::move(out)); + + if (end_split.token_size > chunk_size) { + throw py::value_error( + "split token size (" + std::to_string(end_split.token_size) + ") is greater than chunk size (" + + std::to_string(chunk_size) + ")."); + } + + const int remaining_space = chunk_size - end_split.token_size; + const int overlap_len = std::min({overlap, remaining_space, start_split.token_size}); + if (overlap_len > 0) { + auto start_tokens = encoder(start_split.text).cast>(); + std::vector overlap_tokens(start_tokens.end() - overlap_len, start_tokens.end()); + std::string overlap_text = decoder(overlap_tokens).cast(); + end_split.text = overlap_text + end_split.text; + end_split.token_size += overlap_len; + } + + result.insert(result.begin(), end_split.text); + end_split = start_split; } - std::string key = param_name.cast(); - auto it = defaults.find(key); - if (it == defaults.end() && key == "num_workers") it = defaults.find("worker_num"); - else if (it == defaults.end() && key == "worker_num") it = defaults.find("num_workers"); - if (it == defaults.end()) return py::none(); - return any_to_py(it->second); + result.insert(result.begin(), end_split.text); + return result; + }, + py::arg("splits"), py::arg("chunk_size") + ) + .def("forward", + [](py::object self, py::object node, py::kwargs /*kwargs*/) { + const std::string text = node.attr("get_text")().cast(); + const int metadata_size = self.attr("_get_metadata_size")(node).cast(); + py::object chunks = self.attr("split_text")(text, metadata_size); + + std::vector out; + py::object doc_cls = py::module_::import("lazyllm.tools.rag.doc_node").attr("DocNode"); + for (auto item : chunks) { + py::object chunk = py::reinterpret_borrow(item); + if (chunk.is_none()) continue; + if (py::isinstance(chunk, doc_cls)) { + out.push_back(chunk.cast()); + continue; + } + std::string chunk_text = chunk.cast(); + if (chunk_text.empty()) continue; + 
out.push_back(std::make_shared(std::move(chunk_text))); + } + return out; }, - py::arg("param_name") = py::none() + py::arg("node") ) - .def_static("reset_default", []() { lazyllm::TextSplitterBase::_default_params.reset_default(); }); + .def("_get_param_value", + [](py::object self, const std::string& param_name, py::object value, py::object default_value) { + if (!IsUnset(value)) return value; + py::object cls = py::type::of(self); + py::dict defaults = GetDefaultParams(cls); + if (defaults.contains(param_name.c_str())) { + return py::reinterpret_borrow(defaults[param_name.c_str()]); + } + return default_value; + }, + py::arg("param_name"), py::arg("value"), py::arg("default") + ); + + py::object py_cls = m.attr("_TextSplitterBase"); + py::object builtins = py::module_::import("builtins"); + + py_cls.attr("_get_class_lock") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass) { + if (!py::hasattr(klass, "_default_params_lock")) { + klass.attr("_default_params_lock") = py::module_::import("threading").attr("RLock")(); + } + return klass.attr("_default_params_lock"); + }) + ); + + py_cls.attr("set_default") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass, py::kwargs kwargs) { + py::dict defaults = GetDefaultParams(klass); + for (auto item : kwargs) { + defaults[item.first] = item.second; + } + klass.attr("_default_params") = defaults; + }) + ); + + py_cls.attr("get_default") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass, py::object param_name) -> py::object { + py::dict defaults = GetDefaultParams(klass); + if (param_name.is_none()) return py::object(defaults); + const std::string key = param_name.cast(); + if (defaults.contains(key.c_str())) { + return py::reinterpret_borrow(defaults[key.c_str()]); + } + return py::none(); + }) + ); + + py_cls.attr("reset_default") = builtins.attr("classmethod")( + py::cpp_function([](py::object klass) { + klass.attr("_default_params") = py::dict(); + }) + ); } 
diff --git a/csrc/core/include/doc_node.hpp b/csrc/core/include/doc_node.hpp index e86b2de97..aa54c398c 100644 --- a/csrc/core/include/doc_node.hpp +++ b/csrc/core/include/doc_node.hpp @@ -9,8 +9,8 @@ #include #include #include -#include #include +#include #include "utils.hpp" #include "adaptor_base.hpp" @@ -24,7 +24,8 @@ using PDocNode = std::shared_ptr; class DocNode { public: - using Metadata = std::unordered_map; + using MetadataVType = lazyllm::MetadataVType; + using Metadata = std::unordered_map; using Children = std::unordered_map>; using EmbeddingFun = std::function(const std::string&, const std::string&)>; using EmbeddingVecs = std::unordered_map>; @@ -95,6 +96,7 @@ class DocNode { } void set_store(const std::shared_ptr& p_store) { _p_store = p_store; } const std::string& get_uid() const { return _uid; } + void set_uid(const std::string& uid) { _uid = uid; } const std::string_view& get_text_view() const { return _text_view; } void set_text_view(const std::string_view& text_view) { _text_view = text_view; @@ -141,6 +143,12 @@ class DocNode { _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); return _children; } + Children& get_children_ref() { + if (_children.empty() && _p_store != nullptr) { + _children = std::any_cast(_p_store->call("get_node_children", {{"node", this}})); + } + return _children; + } void set_children(const Children& children) { _children = children; } void set_children_group( const std::string& group_name, @@ -162,7 +170,8 @@ class DocNode { _excluded_llm_metadata_keys = keys; } std::string get_doc_path() const { - return std::any_cast(get_root_node()->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_PATH))); + const auto& value = get_root_node()->_p_global_metadata->at(std::string(RAGMetadataKeys::DOC_PATH)); + return std::get(value); } void set_doc_path(const std::string& path) { get_root_node()->_p_global_metadata->operator[](std::string(RAGMetadataKeys::DOC_PATH)) = path; diff --git 
a/csrc/core/include/node_transform.hpp b/csrc/core/include/node_transform.hpp index c31650374..ca6a8f195 100644 --- a/csrc/core/include/node_transform.hpp +++ b/csrc/core/include/node_transform.hpp @@ -29,6 +29,7 @@ class NodeTransform { std::vector batch_forward(std::vector& nodes, const std::string& node_group_name); int worker_num() const { return _worker_num; } + void set_worker_num(int worker_num) { _worker_num = worker_num; } private: std::vector forward(PDocNode node, const std::string& node_group_name); diff --git a/csrc/core/include/text_splitter_base.hpp b/csrc/core/include/text_splitter_base.hpp index 501b63005..20e30f86c 100644 --- a/csrc/core/include/text_splitter_base.hpp +++ b/csrc/core/include/text_splitter_base.hpp @@ -42,6 +42,7 @@ class TextSplitterBase : public NodeTransform { } std::vector transform(PDocNode node) const override { + if (node == nullptr) return {}; auto chunks = split_text(node->get_text_view(), get_node_metadata_size(*node)); std::vector nodes; nodes.reserve(chunks.size()); @@ -64,6 +65,8 @@ class TextSplitterBase : public NodeTransform { } void set_tokenizer(std::shared_ptr tokenizer) { _tokenizer = std::move(tokenizer); } + int chunk_size() const { return _chunk_size; } + int overlap() const { return _overlap; } virtual void set_split_functions( const std::vector&, diff --git a/csrc/core/include/tokenizer.hpp b/csrc/core/include/tokenizer.hpp index dd667b178..a7f2aada8 100644 --- a/csrc/core/include/tokenizer.hpp +++ b/csrc/core/include/tokenizer.hpp @@ -2,13 +2,17 @@ #include #include +#include +#include #include #include #include #include +#include #include #include +#include #include class Tokenizer { @@ -18,11 +22,35 @@ class Tokenizer { virtual std::string decode(const std::vector& token_ids) const = 0; }; +class FallbackByteTokenizer final : public Tokenizer { +public: + FallbackByteTokenizer() = default; + ~FallbackByteTokenizer() override = default; + + std::vector encode(const std::string_view& view) const override { 
+ std::vector token_ids; + token_ids.reserve(view.size()); + for (unsigned char ch : view) { + token_ids.push_back(static_cast(ch)); + } + return token_ids; + } + + std::string decode(const std::vector& token_ids) const override { + std::string text; + text.reserve(token_ids.size()); + for (int id : token_ids) { + text.push_back(static_cast(id & 0xFF)); + } + return text; + } +}; + class TiktokenTokenizer final : public Tokenizer { public: TiktokenTokenizer() = delete; explicit TiktokenTokenizer(LanguageModel model) - : _encoding(GptEncoding::get_encoding(model)) {} + : _encoding(load_encoding(model)) {} explicit TiktokenTokenizer(std::string_view encoding_name) : TiktokenTokenizer(parse_tiktoken_model(encoding_name)) {} @@ -38,9 +66,79 @@ class TiktokenTokenizer final : public Tokenizer { } private: + class FilePathResourceReader final : public IResourceReader { + public: + explicit FilePathResourceReader(std::filesystem::path resource_path) + : resource_path_(std::move(resource_path)) {} + + std::vector readLines() override { + std::ifstream file(resource_path_); + if (!file.is_open()) { + throw std::runtime_error("Embedded resource '" + resource_path_.string() + "' not found."); + } + std::string line; + std::vector lines; + while (std::getline(file, line)) lines.push_back(line); + return lines; + } + + private: + std::filesystem::path resource_path_; + }; + + static std::string resource_name(LanguageModel model) { + switch (model) { + case LanguageModel::R50K_BASE: return "r50k_base.tiktoken"; + case LanguageModel::P50K_BASE: return "p50k_base.tiktoken"; + case LanguageModel::P50K_EDIT: return "p50k_base.tiktoken"; + case LanguageModel::CL100K_BASE: return "cl100k_base.tiktoken"; + case LanguageModel::O200K_BASE: return "o200k_base.tiktoken"; + case LanguageModel::QWEN_BASE: return "qwen.tiktoken"; + } + throw std::runtime_error("Unknown language model"); + } + + static std::shared_ptr load_encoding(LanguageModel model) { + try { + return 
GptEncoding::get_encoding(model); + } catch (const std::exception&) { + const std::filesystem::path repo_root = + std::filesystem::path(__FILE__).parent_path().parent_path().parent_path().parent_path(); + const std::string file_name = resource_name(model); + const std::vector candidates = { + repo_root / "build" / "tokenizers" / file_name, + repo_root / "tokenizers" / file_name, + std::filesystem::current_path() / "build" / "tokenizers" / file_name, + std::filesystem::current_path() / "tokenizers" / file_name + }; + for (const auto& path : candidates) { + if (!std::filesystem::exists(path)) continue; + FilePathResourceReader reader(path); + return GptEncoding::get_encoding(model, &reader); + } + throw; + } + } + + static bool has_prefix(std::string_view value, std::string_view prefix) { + return value.size() >= prefix.size() && value.compare(0, prefix.size(), prefix) == 0; + } + static LanguageModel parse_tiktoken_model(std::string_view name) { if (name.empty()) return LanguageModel::R50K_BASE; + // Model-name aliases used by Python tiktoken.encoding_for_model. 
+ if (name == "gpt-3.5-turbo" || has_prefix(name, "gpt-3.5-turbo-")) return LanguageModel::CL100K_BASE; + if (name == "gpt-4" || has_prefix(name, "gpt-4-")) return LanguageModel::CL100K_BASE; + if (name == "text-embedding-ada-002") return LanguageModel::CL100K_BASE; + if (name == "text-embedding-3-small" || name == "text-embedding-3-large") return LanguageModel::CL100K_BASE; + if (name == "gpt-4o" || has_prefix(name, "gpt-4o-")) return LanguageModel::O200K_BASE; + if (name == "gpt-4.1" || has_prefix(name, "gpt-4.1-")) return LanguageModel::O200K_BASE; + if (name == "gpt-4.5" || has_prefix(name, "gpt-4.5-")) return LanguageModel::O200K_BASE; + if (name == "o1" || has_prefix(name, "o1-")) return LanguageModel::O200K_BASE; + if (name == "o3" || has_prefix(name, "o3-")) return LanguageModel::O200K_BASE; + if (name == "o4-mini" || has_prefix(name, "o4-mini-")) return LanguageModel::O200K_BASE; + if (name == "gpt2" || name == "r50k_base" || name == "r50k") return LanguageModel::R50K_BASE; if (name == "p50k_base" || name == "p50k") return LanguageModel::P50K_BASE; if (name == "p50k_edit") return LanguageModel::P50K_EDIT; diff --git a/csrc/core/include/utils.hpp b/csrc/core/include/utils.hpp index 023575555..bf9bb1f19 100644 --- a/csrc/core/include/utils.hpp +++ b/csrc/core/include/utils.hpp @@ -7,10 +7,18 @@ #include #include #include +#include #include namespace lazyllm { +using MetadataVType = std::variant< + std::string, std::vector, + int, std::vector, + double, std::vector, + std::any +>; + struct RAGMetadataKeys { inline static constexpr std::string_view KB_ID = "kb_id"; inline static constexpr std::string_view DOC_ID = "docid"; @@ -81,7 +89,7 @@ inline std::string GenerateUUID() { return out; } -std::string any_to_string(const std::any& value); +std::string any_to_string(const MetadataVType& value); inline bool is_adjacent(const std::string_view& left, const std::string_view& right) { return left.data() + left.size() == right.data(); diff --git 
a/csrc/core/src/utils.cpp b/csrc/core/src/utils.cpp index f6dd7736c..620fdb066 100644 --- a/csrc/core/src/utils.cpp +++ b/csrc/core/src/utils.cpp @@ -1,17 +1,48 @@ #include "utils.hpp" +#include + namespace lazyllm { -std::string any_to_string(const std::any& value) { +namespace { + +std::string ScalarToString(const std::string& value) { + return value; +} + +std::string ScalarToString(int value) { + return std::to_string(value); +} + +std::string ScalarToString(double value) { + std::ostringstream oss; + oss << value; + return oss.str(); +} + +template +std::string VectorToString(const std::vector& values) { + std::string out = "["; + for (size_t i = 0; i < values.size(); ++i) { + out += ScalarToString(values[i]); + if (i + 1 < values.size()) out += ","; + } + out += "]"; + return out; +} + +std::string AnyToString(const std::any& value) { + if (!value.has_value()) return ""; const auto& t = value.type(); + if (t == typeid(std::string)) return std::any_cast(value); if (t == typeid(const char*)) return std::string(std::any_cast(value)); if (t == typeid(char*)) return std::string(std::any_cast(value)); - if (t == typeid(bool)) return std::any_cast(value) ? "True" : "False"; + if (t == typeid(bool)) return std::any_cast(value) ? 
"1" : "0"; if (t == typeid(int)) return std::to_string(std::any_cast(value)); if (t == typeid(long)) return std::to_string(std::any_cast(value)); if (t == typeid(long long)) return std::to_string(std::any_cast(value)); - if (t == typeid(unsigned)) return std::to_string(std::any_cast(value)); + if (t == typeid(unsigned int)) return std::to_string(std::any_cast(value)); if (t == typeid(unsigned long)) return std::to_string(std::any_cast(value)); if (t == typeid(unsigned long long)) return std::to_string(std::any_cast(value)); if (t == typeid(float)) { @@ -24,7 +55,29 @@ std::string any_to_string(const std::any& value) { oss << std::any_cast(value); return oss.str(); } - return ""; + if (t == typeid(std::vector)) + return VectorToString(std::any_cast&>(value)); + if (t == typeid(std::vector)) + return VectorToString(std::any_cast&>(value)); + if (t == typeid(std::vector)) + return VectorToString(std::any_cast&>(value)); + return ""; +} + +} // namespace + +std::string any_to_string(const MetadataVType& value) { + return std::visit([](const auto& v) -> std::string { + using T = std::decay_t; + if constexpr (std::is_same_v) return ScalarToString(v); + if constexpr (std::is_same_v) return ScalarToString(v); + if constexpr (std::is_same_v) return ScalarToString(v); + if constexpr (std::is_same_v>) return VectorToString(v); + if constexpr (std::is_same_v>) return VectorToString(v); + if constexpr (std::is_same_v>) return VectorToString(v); + if constexpr (std::is_same_v) return AnyToString(v); + return std::string(""); + }, value); } -} \ No newline at end of file +} // namespace lazyllm diff --git a/csrc/tests/test_doc_node.cpp b/csrc/tests/test_doc_node.cpp index 655c35e20..cc8a27d5a 100644 --- a/csrc/tests/test_doc_node.cpp +++ b/csrc/tests/test_doc_node.cpp @@ -51,13 +51,13 @@ TEST(doc_node, metadata) { {"alpha", std::string("A")}, {"beta", std::string("B")}, }; - EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::ALL), "alpha:A\nbeta:B"); + 
EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::ALL), "alpha: A\nbeta: B"); node.set_excluded_embed_metadata_keys({"beta"}); - EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::EMBED), "alpha:A"); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::EMBED), "alpha: A"); node.set_excluded_llm_metadata_keys({"alpha"}); - EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::LLM), "beta:B"); + EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::LLM), "beta: B"); EXPECT_EQ(node.get_metadata_string(lazyllm::MetadataMode::NONE), ""); @@ -81,7 +81,7 @@ TEST(doc_node, text) { lazyllm::DocNode node("body"); node._metadata = lazyllm::DocNode::Metadata{{"alpha", std::string("A")}}; EXPECT_EQ(node.get_text(lazyllm::MetadataMode::NONE), "body"); - EXPECT_EQ(node.get_text(lazyllm::MetadataMode::EMBED), "alpha:A\n\nbody"); + EXPECT_EQ(node.get_text(lazyllm::MetadataMode::EMBED), "alpha: A\n\nbody"); } TEST(doc_node, relationships) { @@ -147,7 +147,7 @@ TEST(doc_node, py_do_embedding_writes_embedding_vector) { ASSERT_TRUE(node._embedding_vecs.find("len_embedding") != node._embedding_vecs.end()); ASSERT_EQ(node._embedding_vecs["len_embedding"].size(), 1u); - EXPECT_EQ(node._embedding_vecs["len_embedding"][0], 6.0); // "\n\ntext" + EXPECT_EQ(node._embedding_vecs["len_embedding"][0], 4.0); // "text" } TEST(doc_node, equality_uses_uid) { diff --git a/csrc/tests/test_utils.cpp b/csrc/tests/test_utils.cpp index 58660e91b..9b0850865 100644 --- a/csrc/tests/test_utils.cpp +++ b/csrc/tests/test_utils.cpp @@ -74,3 +74,21 @@ TEST(utils, rag_metadata_keys_constants_are_exposed) { EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_PATH, "lazyllm_doc_path"); EXPECT_EQ(lazyllm::RAGMetadataKeys::DOC_ID, "docid"); } + +TEST(utils, any_to_string_formats_scalar_metadata_values) { + EXPECT_EQ(lazyllm::any_to_string(lazyllm::MetadataVType(std::string("alpha"))), "alpha"); + EXPECT_EQ(lazyllm::any_to_string(lazyllm::MetadataVType(7)), "7"); + 
EXPECT_EQ(lazyllm::any_to_string(lazyllm::MetadataVType(3.5)), "3.5"); +} + +TEST(utils, any_to_string_formats_vector_metadata_values_with_brackets) { + EXPECT_EQ( + lazyllm::any_to_string(lazyllm::MetadataVType(std::vector{"a", "b"})), + "[a,b]"); + EXPECT_EQ( + lazyllm::any_to_string(lazyllm::MetadataVType(std::vector{1, 2, 3})), + "[1,2,3]"); + EXPECT_EQ( + lazyllm::any_to_string(lazyllm::MetadataVType(std::vector{1.5, 2.0})), + "[1.5,2]"); +}