Skip to content

Commit 4ba5d78

Browse files
committed
refactor: update causal LM implementations to inherit from LlmForCausalLMImplBase.
1 parent 0e40473 commit 4ba5d78

File tree

6 files changed

+36
-357
lines changed

6 files changed

+36
-357
lines changed

xllm/models/llm/npu/deepseek_v2.h

Lines changed: 9 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ limitations under the License.
3434
#include "core/layers/npu/npu_rms_norm_impl.h"
3535
#include "core/layers/npu/npu_word_embedding_impl.h"
3636
#include "core/layers/npu/rotary_embedding.h"
37+
#include "llm_model_base.h"
3738
#include "models/model_registry.h"
3839
// DeepSeek v2 compatible with huggingface weights
3940
// ref to:
@@ -255,72 +256,25 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
255256
};
256257
TORCH_MODULE(DeepseekV2Model);
257258

258-
class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
259+
class DeepseekV2ForCausalLMImpl
260+
: public LlmForCausalLMImplBase<DeepseekV2Model> {
259261
public:
260-
DeepseekV2ForCausalLMImpl(const ModelContext& context) {
261-
model_ = register_module("model", DeepseekV2Model(context));
262-
npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
263-
first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
264-
}
265-
266-
// tokens: [num_tokens]
267-
// positions: [num_tokens] token pos in the sequence
268-
// returns: [num_tokens, hidden_size]
269-
torch::Tensor forward(const torch::Tensor& tokens,
270-
const torch::Tensor& positions,
271-
std::vector<KVCache>& kv_caches,
272-
const ModelInputParams& input_params) {
273-
return model_(tokens, positions, kv_caches, input_params);
274-
}
275-
276-
// hidden_states: [num_tokens, hidden_size]
277-
// seleted_idxes: [num_tokens]
278-
// returns: [num_tokens, vocab_size]
279-
torch::Tensor logits(const torch::Tensor& hidden_states,
280-
const torch::Tensor& seleted_idxes) {
281-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
282-
}
283-
284-
void load_model(std::unique_ptr<ModelLoader> loader) {
285-
for (const auto& state_dict : loader->get_state_dicts()) {
286-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
287-
npu_lm_head_->load_state_dict(
288-
state_dict->get_dict_with_prefix("lm_head."));
289-
}
290-
291-
// verify
292-
model_->verify_loaded_weights("model.");
293-
npu_lm_head_->verify_loaded_weights("lm_head.");
294-
295-
model_->merge_loaded_weights();
296-
npu_lm_head_->merge_loaded_weights();
297-
}
262+
DeepseekV2ForCausalLMImpl(const ModelContext& context)
263+
: LlmForCausalLMImplBase<DeepseekV2Model>(context),
264+
first_k_dense_replace_(
265+
context.get_model_args().first_k_dense_replace()) {}
298266

299267
void prepare_expert_weight(int32_t layer_id,
300-
const std::vector<int32_t>& expert_ids) {
268+
const std::vector<int32_t>& expert_ids) override {
301269
model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
302270
expert_ids);
303271
}
304272

305-
void update_expert_weight(int32_t layer_id) {
273+
void update_expert_weight(int32_t layer_id) override {
306274
model_->update_expert_weight(layer_id + first_k_dense_replace_);
307275
}
308276

309-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
310-
311-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
312-
313-
layer::NpuWordEmbedding get_npu_word_embedding() {
314-
return model_->get_npu_word_embedding();
315-
}
316-
317-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
318-
model_->set_npu_word_embedding(npu_word_embedding);
319-
}
320-
321277
private:
322-
DeepseekV2Model model_{nullptr};
323-
layer::NpuLmHead npu_lm_head_{nullptr};
324278
int32_t first_k_dense_replace_;
325279
};
326280
TORCH_MODULE(DeepseekV2ForCausalLM);

xllm/models/llm/npu/deepseek_v2_mtp.h

Lines changed: 9 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include "core/layers/npu/npu_column_parallel_linear_impl.h"
1919
#include "deepseek_v2.h"
20+
#include "llm_model_base.h"
2021

2122
// DeepSeek v2 compatible with huggingface weights
2223
// ref to:
@@ -196,67 +197,22 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module {
196197
};
197198
TORCH_MODULE(DeepseekV2MtpModel);
198199

199-
class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module {
200+
class DeepseekV2MtpForCausalLMImpl
201+
: public LlmForCausalLMImplBase<DeepseekV2MtpModel> {
200202
public:
201-
DeepseekV2MtpForCausalLMImpl(const ModelContext& context) {
202-
model_ = register_module("model", DeepseekV2MtpModel(context));
203-
}
204-
205-
// tokens: [num_tokens]
206-
// positions: [num_tokens] token pos in the sequence
207-
// returns: [num_tokens, hidden_size]
208-
torch::Tensor forward(const torch::Tensor& tokens,
209-
const torch::Tensor& positions,
210-
std::vector<KVCache>& kv_caches,
211-
const ModelInputParams& input_params) {
212-
return model_(tokens, positions, kv_caches, input_params);
213-
}
203+
DeepseekV2MtpForCausalLMImpl(const ModelContext& context)
204+
: LlmForCausalLMImplBase<DeepseekV2MtpModel>(context) {}
214205

215-
// hidden_states: [num_tokens, hidden_size]
216-
// seleted_idxes: [num_tokens]
217-
// returns: [num_tokens, vocab_size]
218-
torch::Tensor logits(const torch::Tensor& hidden_states,
219-
const torch::Tensor& seleted_idxes) {
220-
// select tokens if provided
221-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
222-
}
223-
224-
// load model
225-
void load_model(std::unique_ptr<ModelLoader> loader) {
206+
void load_model(std::unique_ptr<ModelLoader> loader,
207+
std::string prefix = "model.") override {
226208
for (const auto& state_dict : loader->get_state_dicts()) {
227-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
228-
// npu_lm_head_->load_state_dict(state_dict.get_dict_with_prefix("model.shared_head.head."));
209+
model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
229210
}
230211

231-
// verify
232-
model_->verify_loaded_weights("model.");
233-
// npu_lm_head_->verify_loaded_weights("model.shared_head.head.");
212+
model_->verify_loaded_weights(prefix);
234213

235214
model_->merge_loaded_weights();
236-
// npu_lm_head_->merge_loaded_weights();
237-
}
238-
239-
void prepare_expert_weight(int32_t layer_id,
240-
const std::vector<int32_t>& expert_ids) {
241-
return;
242-
}
243-
void update_expert_weight(int32_t layer_id) { return; }
244-
245-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
246-
247-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
248-
249-
layer::NpuWordEmbedding get_npu_word_embedding() {
250-
return model_->get_npu_word_embedding();
251215
}
252-
253-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
254-
model_->set_npu_word_embedding(npu_word_embedding);
255-
}
256-
257-
private:
258-
DeepseekV2MtpModel model_{nullptr};
259-
layer::NpuLmHead npu_lm_head_{nullptr};
260216
};
261217
TORCH_MODULE(DeepseekV2MtpForCausalLM);
262218

xllm/models/llm/npu/glm4_moe.h

Lines changed: 3 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -285,74 +285,10 @@ class Glm4MoeModelImpl : public torch::nn::Module {
285285
};
286286
TORCH_MODULE(Glm4MoeModel);
287287

288-
class Glm4MoeForCausalLMImpl : public torch::nn::Module {
288+
class Glm4MoeForCausalLMImpl : public LlmForCausalLMImplBase<Glm4MoeModel> {
289289
public:
290-
Glm4MoeForCausalLMImpl(const ModelContext& context) {
291-
model_ = register_module("model", Glm4MoeModel(context));
292-
npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
293-
}
294-
295-
torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
296-
return model_->get_input_embeddings(input_ids);
297-
}
298-
299-
// tokens: [num_tokens]
300-
// positions: [num_tokens] token pos in the sequence
301-
// returns: [num_tokens, hidden_size]
302-
torch::Tensor forward(const torch::Tensor& tokens,
303-
const torch::Tensor& positions,
304-
std::vector<KVCache>& kv_caches,
305-
const ModelInputParams& input_params) {
306-
return model_(tokens, positions, kv_caches, input_params);
307-
}
308-
309-
// hidden_states: [num_tokens, hidden_size]
310-
// seleted_idxes: [num_tokens]
311-
// returns: [num_tokens, vocab_size]
312-
torch::Tensor logits(const torch::Tensor& hidden_states,
313-
const torch::Tensor& seleted_idxes) {
314-
// select tokens if provided
315-
auto h = hidden_states;
316-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
317-
}
318-
319-
void load_model(std::unique_ptr<ModelLoader> loader,
320-
std::string prefix = "model." /*llm model weight prefix*/) {
321-
for (const auto& state_dict : loader->get_state_dicts()) {
322-
model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
323-
npu_lm_head_->load_state_dict(
324-
state_dict->get_dict_with_prefix("lm_head."));
325-
}
326-
327-
// verify
328-
model_->verify_loaded_weights(prefix);
329-
npu_lm_head_->verify_loaded_weights("lm_head.");
330-
331-
model_->merge_loaded_weights();
332-
npu_lm_head_->merge_loaded_weights();
333-
}
334-
335-
virtual void prepare_expert_weight(int32_t layer_id,
336-
const std::vector<int32_t>& expert_ids) {
337-
return;
338-
}
339-
virtual void update_expert_weight(int32_t layer_id) { return; }
340-
341-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
342-
343-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
344-
345-
layer::NpuWordEmbedding get_npu_word_embedding() {
346-
return model_->get_npu_word_embedding();
347-
}
348-
349-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
350-
model_->set_npu_word_embedding(npu_word_embedding);
351-
}
352-
353-
private:
354-
Glm4MoeModel model_{nullptr};
355-
layer::NpuLmHead npu_lm_head_{nullptr};
290+
Glm4MoeForCausalLMImpl(const ModelContext& context)
291+
: LlmForCausalLMImplBase<Glm4MoeModel>(context) {}
356292
};
357293
TORCH_MODULE(Glm4MoeForCausalLM);
358294

xllm/models/llm/npu/glm4_moe_mtp.h

Lines changed: 8 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -236,67 +236,23 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
236236
};
237237
TORCH_MODULE(Glm4MoeMtpModel);
238238

239-
class Glm4MoeMtpForCausalLMImpl : public torch::nn::Module {
239+
class Glm4MoeMtpForCausalLMImpl
240+
: public LlmForCausalLMImplBase<Glm4MoeMtpModel> {
240241
public:
241-
Glm4MoeMtpForCausalLMImpl(const ModelContext& context) {
242-
model_ = register_module("model", Glm4MoeMtpModel(context));
243-
}
244-
245-
// tokens: [num_tokens]
246-
// positions: [num_tokens] token pos in the sequence
247-
// returns: [num_tokens, hidden_size]
248-
torch::Tensor forward(const torch::Tensor& tokens,
249-
const torch::Tensor& positions,
250-
std::vector<KVCache>& kv_caches,
251-
const ModelInputParams& input_params) {
252-
return model_(tokens, positions, kv_caches, input_params);
253-
}
242+
Glm4MoeMtpForCausalLMImpl(const ModelContext& context)
243+
: LlmForCausalLMImplBase<Glm4MoeMtpModel>(context) {}
254244

255-
// hidden_states: [num_tokens, hidden_size]
256-
// seleted_idxes: [num_tokens]
257-
// returns: [num_tokens, vocab_size]
258-
torch::Tensor logits(const torch::Tensor& hidden_states,
259-
const torch::Tensor& seleted_idxes) {
260-
// select tokens if provided
261-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
262-
}
263-
264-
// load model
265-
void load_model(std::unique_ptr<ModelLoader> loader) {
245+
void load_model(std::unique_ptr<ModelLoader> loader,
246+
std::string prefix = "model.") override {
266247
for (const auto& state_dict : loader->get_state_dicts()) {
267-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
268-
// npu_lm_head_->load_state_dict(state_dict.get_dict_with_prefix("model.shared_head.head."));
248+
model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
269249
}
270250

271251
// verify
272-
model_->verify_loaded_weights("model.");
273-
// npu_lm_head_->verify_loaded_weights("model.shared_head.head.");
252+
model_->verify_loaded_weights(prefix);
274253

275254
model_->merge_loaded_weights();
276-
// npu_lm_head_->merge_loaded_weights();
277-
}
278-
279-
void prepare_expert_weight(int32_t layer_id,
280-
const std::vector<int32_t>& expert_ids) {
281-
return;
282-
}
283-
void update_expert_weight(int32_t layer_id) { return; }
284-
285-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
286-
287-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
288-
289-
layer::NpuWordEmbedding get_npu_word_embedding() {
290-
return model_->get_npu_word_embedding();
291255
}
292-
293-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
294-
model_->set_npu_word_embedding(npu_word_embedding);
295-
}
296-
297-
private:
298-
Glm4MoeMtpModel model_{nullptr};
299-
layer::NpuLmHead npu_lm_head_{nullptr};
300256
};
301257
TORCH_MODULE(Glm4MoeMtpForCausalLM);
302258

0 commit comments

Comments
 (0)