55#include < cstring>
66#include < limits>
77#include < regex>
8+ #include < iostream>
89
910#include " unicode.h"
1011#include " chat.h"
@@ -474,6 +475,7 @@ size_t BPEProcessor2::Load(DataReader *data_reader, int n_vocab)
474475 vocab_.id_to_token .resize (piece_size);
475476 load_vocab_merges (vocab_, reader);
476477 build_special_token_cache (vocab_);
478+ searcher.rebuild (vocab_.special_tokens_cache );
477479
478480 return reader.get_total_size ();
479481}
@@ -694,37 +696,115 @@ const std::string BPEProcessor3::IdToPiece(int id) const
694696 }
695697}
696698
697- static std::string search_first_special_token (std::string &input, const _vocab &vocab , int &sp_tok_id )
699+ NearestKeywordSearcher::Node * NearestKeywordSearcher::make_tree (std::vector<Item> &items, char ch , int value )
698700{
699- sp_tok_id = -1 ;
700- auto nearest_match = std::string::npos;
701- for (auto & st: vocab.special_tokens_cache )
702- {
703- const auto & special_id = st.first ;
704- const auto & special_token = st.second ;
701+ Node * r = new Node ();
702+ r->ch = ch;
703+ r->value = items.size () < 1 ? value : -1 ;
705704
706- auto match = input.find (special_token, 0 );
705+ while (true )
706+ {
707+ bool flag = false ;
708+ char tag = 0 ;
709+ int v = -1 ;
707710
708- if (match < nearest_match)
711+ std::vector<Item> sub;
712+ for (int i = (int )items.size () - 1 ; i >= 0 ; i--)
709713 {
710- nearest_match = match;
711- sp_tok_id = special_id;
714+ if (items[i].s .size () < 1 ) continue ;
715+
716+ if (!flag)
717+ {
718+ flag = true ;
719+ tag = items[i].s [0 ];
720+ v = items[i].value ;
721+ }
722+ else
723+ {
724+ if (items[i].s [0 ] != tag) continue ;
725+ }
726+
727+ if (items[i].s .size () > 1 )
728+ sub.emplace_back (items[i].s .substr (1 ), items[i].value );
729+
730+ // mark as visited
731+ items[i].s = " " ;
712732 }
733+
734+ if (!flag) break ;
735+
736+ Node *child = make_tree (sub, tag, v);
737+ r->child .emplace_back (std::unique_ptr<Node>(child));
713738 }
714739
715- if (sp_tok_id >= 0 )
740+ std::sort (r->child .begin (), r->child .end (), [](auto &p1, auto &p2) { return p1->ch <= p2->ch ; });
741+
742+ return r;
743+ }
744+
745+ void NearestKeywordSearcher::rebuild (const std::unordered_map<int , std::string> keywords)
746+ {
747+ root.reset (nullptr );
748+
749+ std::vector<Item> sub;
750+
751+ for (auto & st: keywords)
716752 {
717- const auto & special_token = vocab.special_tokens_cache .at (sp_tok_id);
718- std::string r = input.substr (0 , nearest_match);
719- input = input.substr (nearest_match + special_token.size ());
720- return r;
753+ sub.emplace_back (st.second , st.first );
721754 }
722- else
755+ root.reset (make_tree (sub, 0 , -1 ));
756+ }
757+
758+ int NearestKeywordSearcher::match (const std::string &input, int index, Node *node, int &level) const
759+ {
760+ if (node->child .size () < 1 ) return node->value ;
761+ if (index >= (int )input.size ()) return -1 ;
762+ const char ch = input[index];
763+
764+ int low = 0 ;
765+ int high = (int )node->child .size () - 1 ;
766+ while (high >= low)
767+ {
768+ // assuming no overflow
769+ int middle = (high + low) / 2 ;
770+ Node *n = node->child [middle].get ();
771+ if (n->ch < ch)
772+ {
773+ low = middle + 1 ;
774+ }
775+ else if (ch < n->ch )
776+ {
777+ high = middle - 1 ;
778+ }
779+ else
780+ {
781+ level++;
782+ return match (input, index + 1 , n, level);
783+ }
784+ }
785+
786+ return -1 ;
787+ }
788+
789+ std::string NearestKeywordSearcher::search (std::string &input, int &kw_id) const
790+ {
791+ int index = 0 ;
792+ while (index < (int )input.size ())
723793 {
724- std::string r (input);
725- input = " " ;
726- return r;
794+ int len = 0 ;
795+ kw_id = match (input, index, root.get (), len);
796+ if (kw_id >= 0 )
797+ {
798+ std::string r = input.substr (0 , index);
799+ input = input.substr (index + len);
800+ return r;
801+ }
802+ index++;
727803 }
804+
805+ std::string r (input);
806+ input = " " ;
807+ return r;
728808}
729809
730810int BPEProcessor2::DoEncode (const std::string &input,
@@ -734,7 +814,7 @@ int BPEProcessor2::DoEncode(const std::string &input,
734814 int sp_tok_id = -1 ;
735815 while (text.size () > 0 )
736816 {
737- auto leading = search_first_special_token (text, vocab_ , sp_tok_id);
817+ auto leading = searcher. search (text, sp_tok_id);
738818 DoEncode2 (leading, ids);
739819 if (sp_tok_id < 0 ) break ;
740820 ids->push_back (sp_tok_id);
0 commit comments