Commit 1899886

support ERNIE-4.5 dense

1 parent 50e24bb

File tree

11 files changed: +1108 −829 lines

CMakeLists.txt

Lines changed: 4 additions & 1 deletion

@@ -55,8 +55,11 @@ set(core_files src/backend.cpp
     src/unicode-data.cpp
     src/vision_process.cpp
     src/audio_process.cpp
+    models/ernie.cpp
+    models/hunyuan.cpp
+    models/llama.cpp
     models/qwen.cpp
-    models/hunyuan.cpp)
+    )

 add_library(libchatllm SHARED EXCLUDE_FROM_ALL src/main.cpp ${core_files})
 target_link_libraries(libchatllm PRIVATE ggml)

README.md

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g

 **What's New:**

-* 2025-06-30: Hunyuan-A13B
+* 2025-06-30: Hunyuan-A13B, ERNIE-Dense
 * 2025-06-21: [I can hear](./docs/multimodal.md): Qwen2-Audio
 * 2025-06-10: SmolVLM2
 * 2025-06-07: MiniCPM4

convert.py

Lines changed: 31 additions & 0 deletions

@@ -166,6 +166,7 @@ class ModelType(Enum):
     Exaone = 0x1705
     DeepSeek_R1_Distill_LlaMA = 0x1706
     Aquila2 = 0x1707
+    ERNIE_DENSE = 0x1708

     StarCoder2 = 0x1800

@@ -1359,6 +1360,34 @@ def dump_config(f, config, ggml_type):
     def get_weight_names(config):
         return LlamaConverter.get_weight_names(config)

+class ERNIEDenseConverter(BaseConverter):
+    MODEL_TYPE = ModelType.ERNIE_DENSE
+
+    @classmethod
+    def pp(cls, config, name: str, tensor):
+        return Llama3Converter.pp(config, name, tensor)
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        if config.rope_scaling is not None:
+            assert config.rope_scaling == 1.0, 'rope_scaling must equal to 1.0'
+
+        dump_llama_like_config(f, config, ggml_type)
+        config_values = [
+            config.num_key_value_heads,
+            config.head_dim,
+            1 if config.tie_word_embeddings else 0,
+        ]
+        f.write(struct.pack("i" * len(config_values), *config_values))
+        f.write(struct.pack("<f", config.rope_theta))
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = Llama3Converter.get_weight_names(config)
+        if (config.tie_word_embeddings is not None) and config.tie_word_embeddings:
+            weight_names.remove('lm_head.weight')
+        return weight_names
+
 class Llama31Converter(BaseConverter):
     MODEL_TYPE = ModelType.LlaMA31

@@ -7485,6 +7514,8 @@ def main():
         AprielConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch in ['Qwen3MoeForCausalLM', 'Qwen3ForCausalLM']:
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'Ernie4_5_ForCausalLM':
+        ERNIEDenseConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'deepseek-r1-distill-qwen3':
         QWen3Converter.MODEL_TYPE = ModelType.DeepSeek_R1_Distill_QWen3
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
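
For reference, a minimal sketch of the ERNIE-specific bytes this converter appends after the llama-like header, mirroring `dump_config` above (the sample values are illustrative, not taken from a real checkpoint):

    import struct

    # Mirrors ERNIEDenseConverter.dump_config: three ints (native order, as in
    # the converter) followed by one little-endian float32.
    def pack_ernie_dense_tail(num_key_value_heads, head_dim, tie_word_embeddings, rope_theta):
        ints = [num_key_value_heads, head_dim, 1 if tie_word_embeddings else 0]
        return struct.pack("i" * len(ints), *ints) + struct.pack("<f", rope_theta)

    blob = pack_ernie_dense_tail(2, 128, True, 500000.0)  # illustrative values
    assert len(blob) == 16  # 3 * int32 + 1 * float32

These are the same four fields that `models/ernie.h` declares in `Config` (see below), in the same order.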

docs/models.md

Lines changed: 3 additions & 0 deletions

@@ -58,6 +58,9 @@

     Two optimization modes are defined: speed (default) and memory. See `BaseMLAttention`.

+* ERNIE (`Ernie4_5_ForCausalLM`)
+    * [x] [0.3B](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT/tree/c163aa422d265f995b024d1322d91c4e3cb52ec8)
+
 * EXAONE (`ExaoneForCausalLM`)
     * [x] v3.5: [Instruct-2.4B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct), [Instruct-7.8B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct), [Instruct-32B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-32B-Instruct)
     * [x] Deep: [2.4B](https://huggingface.co/LGAI-EXAONE/EXAONE-Deep-2.4B/tree/b9e0d963cc9be39abce33381f40a8da4324cf4bb), [7.8B](https://huggingface.co/LGAI-EXAONE/EXAONE-Deep-7.8B/tree/19948cbbd0e9afb0f7b5a918eb7e2eb18341e076), [32B](https://huggingface.co/LGAI-EXAONE/EXAONE-Deep-32B/tree/dfa797fc8d8ae6ecc0e5f7a450317cc1433b2545)

models/ernie.cpp

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+#include "ernie.h"
+
+namespace chatllm::ernie::dense
+{
+    class ChatHistoryEncoder : public BaseHistoryEncoder
+    {
+    public:
+        void append_sys_prompt(std::vector<int> &ids) const override;
+        void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
+        void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
+        void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
+    };
+
+    static ChatHistoryEncoder _chat_encoder;
+
+    Tokenizer::Tokenizer(const Config &config)
+        : chatllm::llama::v2::Tokenizer(config, &_chat_encoder)
+    {}
+
+    void ChatHistoryEncoder::append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        append_ai_opening(round_idx, ids);
+        tok->encode(ai, ids, false, true);
+    }
+
+    void ChatHistoryEncoder::append_sys_prompt(std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        std::ostringstream oss_prompt;
+
+        ids.push_back(tok->bos_token_id);
+        if (tok->get_system_prompt().size() > 0)
+        {
+            oss_prompt << tok->get_system_prompt() << "\n";
+            auto text = oss_prompt.str();
+            tok->encode(text, ids);
+        }
+    }
+
+    void ChatHistoryEncoder::append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        std::ostringstream oss_prompt;
+
+        oss_prompt << "User: " + user << "\n";
+        auto text = oss_prompt.str();
+        tok->encode(text, ids);
+    }
+
+    void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
+    {
+        Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+        tok->encode("Assistant: ", ids);
+    }
+
+    ConditionalGeneration::ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type)
+        : chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>(config, runtime_config, type,
+            config.num_key_value_heads, config.head_dim, config.max_length, 12, config.tie_word_embeddings != 0)
+    {
+        auto transformer = Base::get_typed_transformer<ModelClass2>();
+        for (int i = 0; i < config.num_hidden_layers; i++)
+        {
+            auto &attention = transformer->layers[i].attention;
+            attention.freq_base = config.rope_theta;
+        }
+    }
+}
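
As a rough sketch, this is the text layout `ChatHistoryEncoder` assembles for a single round (the leading BOS token and the tokenization step are elided; the literal strings match the encoder above):

    # Sketch of the prompt text built by append_sys_prompt / append_user /
    # append_ai_opening (BOS token and tokenization elided).
    def ernie_dense_prompt(system, user):
        out = ""
        if system:                       # append_sys_prompt
            out += system + "\n"
        out += "User: " + user + "\n"    # append_user
        out += "Assistant: "             # append_ai_opening; the model generates from here
        return out

    print(ernie_dense_prompt("You are a helpful assistant.", "Hello!"))

For rounds that already contain an assistant reply, `append_ai` emits the same "Assistant: " opening and then encodes the reply text.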

models/ernie.h

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+#pragma once
+
+#include "../src/models.h"
+#include "../src/models_priv.h"
+
+#include "llama.h"
+
+namespace chatllm::ernie::dense
+{
+    struct Config : public chatllm::llama::v2::Config
+    {
+        int num_key_value_heads;
+        int head_dim;
+        int tie_word_embeddings;
+        float rope_theta;
+    };
+
+    class Tokenizer : public chatllm::llama::v2::Tokenizer
+    {
+    public:
+        Tokenizer(const Config &config);
+    };
+
+    class ConditionalGeneration : public chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>
+    {
+    public:
+        ConditionalGeneration() = default;
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_ERNIE_DENSE);
+    };
+}
