#include "ernie.h"
2+
3+ namespace chatllm ::ernie::dense
4+ {
5+ class ChatHistoryEncoder : public BaseHistoryEncoder
6+ {
7+ public:
8+ void append_sys_prompt (std::vector<int > &ids) const override ;
9+ void append_ai (int round_idx, const std::string &ai, std::vector<int > &ids) const override ;
10+ void append_user (int round_idx, const std::string &user, std::vector<int > &ids) const override ;
11+ void append_ai_opening (int round_idx, std::vector<int > &ids) const override ;
12+ };
13+
14+ static ChatHistoryEncoder _chat_encoder;
15+
16+ Tokenizer::Tokenizer (const Config &config)
17+ : chatllm::llama::v2::Tokenizer(config, &_chat_encoder)
18+ {}
19+
20+ void ChatHistoryEncoder::append_ai (int round_idx, const std::string &ai, std::vector<int > &ids) const
21+ {
22+ Tokenizer *tok = dynamic_cast <Tokenizer *>(tokenizer);
23+ append_ai_opening (round_idx, ids);
24+ tok->encode (ai, ids, false , true );
25+ }
26+
27+ void ChatHistoryEncoder::append_sys_prompt (std::vector<int > &ids) const
28+ {
29+ Tokenizer *tok = dynamic_cast <Tokenizer *>(tokenizer);
30+ std::ostringstream oss_prompt;
31+
32+ ids.push_back (tok->bos_token_id );
33+ if (tok->get_system_prompt ().size () > 0 )
34+ {
35+ oss_prompt << tok->get_system_prompt () << " \n " ;
36+ auto text = oss_prompt.str ();
37+ tok->encode (text, ids);
38+ }
39+ }
40+
41+ void ChatHistoryEncoder::append_user (int round_idx, const std::string &user, std::vector<int > &ids) const
42+ {
43+ Tokenizer *tok = dynamic_cast <Tokenizer *>(tokenizer);
44+ std::ostringstream oss_prompt;
45+
46+ oss_prompt << " User: " + user << " \n " ;
47+ auto text = oss_prompt.str ();
48+ tok->encode (text, ids);
49+ }
50+
51+ void ChatHistoryEncoder::append_ai_opening (int round_idx, std::vector<int > &ids) const
52+ {
53+ Tokenizer *tok = dynamic_cast <Tokenizer *>(tokenizer);
54+ tok->encode (" Assistant: " , ids);
55+ }
56+
57+ ConditionalGeneration::ConditionalGeneration (const Config &config, const RuntimeConfig &runtime_config, ModelType type)
58+ : chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>(config, runtime_config, type,
59+ config.num_key_value_heads, config.head_dim, config.max_length, 12 , config.tie_word_embeddings != 0 )
60+ {
61+ auto transformer = Base::get_typed_transformer<ModelClass2>();
62+ for (int i = 0 ; i < config.num_hidden_layers ; i++)
63+ {
64+ auto &attention = transformer->layers [i].attention ;
65+ attention.freq_base = config.rope_theta ;
66+ }
67+ }
68+ }