Skip to content

Commit 43e6b00

Browse files
committed
support llama3.
1 parent 253a41a commit 43e6b00

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

include/llm.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,19 @@ class Yi_6b : public Llama2_7b {
313313
virtual std::vector<int> tokenizer(const std::string& query) override;
314314
virtual bool is_stop(int token_id) override;
315315
};
316+
317+
class Llama3_8b : public Llama2_7b {
318+
public:
319+
Llama3_8b() {
320+
model_name_ = "Llama3_8b";
321+
layer_nums_ = 32;
322+
key_value_shape_ = {2, 1, 8, 0, 128};
323+
hidden_size_ = 4096;
324+
}
325+
private:
326+
virtual std::vector<int> tokenizer(const std::string& query) override;
327+
virtual bool is_stop(int token_id) override;
328+
};
316329
// Llm end
317330

318331
// Embedding start

src/llm.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ Llm* Llm::createLLM(const std::string& path, std::string model_type) {
8686
} else if (model_type.find("yi") != std::string::npos) {
8787
llm = new Yi_6b;
8888
llm->model_name_ = "Yi_6b";
89+
} else if (model_type.find("llama3") != std::string::npos) {
90+
llm = new Llama3_8b;
91+
llm->model_name_ = "Llama3_8b";
8992
}
9093
if (!llm) {
9194
std::cerr << "model type can't judge!" << std::endl;
@@ -796,6 +799,18 @@ std::vector<int> Yi_6b::tokenizer(const std::string& query) {
796799
/// Returns true when `token_id` is one of the ids that end generation for Yi_6b.
bool Yi_6b::is_stop(int token_id) {
    // NOTE(review): meaning of the individual ids is not visible here; verify against the Yi vocabulary.
    static constexpr int kStopIds[] = {7, 64001};
    for (const int stop_id : kStopIds) {
        if (token_id == stop_id) {
            return true;
        }
    }
    return false;
}
802+
803+
std::vector<int> Llama3_8b::tokenizer(const std::string& query) {
804+
// <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n+query+<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n
805+
auto ids = tokenizer_encode(query);
806+
ids.insert(ids.begin(), {128000, 128006, 882, 128007, 271});
807+
ids.insert(ids.end(), {128009, 128006, 78191, 128007, 271});
808+
return ids;
809+
}
810+
811+
/// Returns true when `token_id` should terminate generation for Llama3_8b.
bool Llama3_8b::is_stop(int token_id) {
    switch (token_id) {
        case 128001:  // NOTE(review): presumably <|end_of_text|>; confirm against the vocabulary
        case 128009:  // <|eot_id|>, the same id tokenizer() appends after the user query
            return true;
        default:
            return false;
    }
}
799814
// Llm end
800815

801816
// Embedding start

0 commit comments

Comments
 (0)