
Commit ba13bb4

refactor: clean up conditional compilation directives in model headers.
1 parent: d98b59e · commit: ba13bb4

File tree

6 files changed (+21, -107 lines)
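The pattern repeated across these headers: code that used to pick between an NPU-specific header and a common fallback via a USE_NPU guard with an #else branch now includes the common header unconditionally and keeps only the NPU-specific include behind the guard. A minimal before/after sketch of the idea, using the lm_head headers from causal_lm.h (the exact guard spelling, #if defined(USE_NPU), is taken from config.h and assumed to match):

// Before: the portable header is only visible when USE_NPU is off.
#if defined(USE_NPU)
#include "layers/npu/npu_lm_head_impl.h"
#else
#include "layers/lm_head.h"
#endif

// After: the portable header is always included, so the generic layer::LmHead
// API exists on every backend; only the NPU implementation stays guarded.
#if defined(USE_NPU)
#include "layers/npu/npu_lm_head_impl.h"
#endif
#include "layers/lm_head.h"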


xllm/core/framework/model/causal_lm.h

Lines changed: 4 additions & 9 deletions
@@ -20,9 +20,6 @@ limitations under the License.
 #include "graph/types.h"
 #include "layers/npu/npu_lm_head_impl.h"
 #include "layers/npu/npu_word_embedding_impl.h"
-#else
-#include "layers/lm_head.h"
-#include "layers/word_embedding.h"
 #endif
 // clang-format on
 #include <c10/core/Device.h>
@@ -34,12 +31,13 @@ limitations under the License.
 #include "core/framework/model_loader.h"
 #include "core/framework/quant_args.h"
 #include "core/framework/state_dict/state_dict.h"
+#include "layers/lm_head.h"
+#include "layers/word_embedding.h"
 #include "model_args.h"
 #include "model_input_params.h"

 namespace xllm {

-#if !defined(USE_NPU)
 namespace detail {
 template <typename T, typename = void>
 struct has_get_lm_head : std::false_type {};
@@ -76,7 +74,6 @@ struct has_set_word_embedding<
     std::void_t<decltype(std::declval<T>()->set_word_embedding(
         std::declval<layer::WordEmbedding&>()))>> : std::true_type {};
 }  // namespace detail
-#endif

 class CausalLM : public torch::nn::Module {
  public:
@@ -113,7 +110,7 @@ class CausalLM : public torch::nn::Module {
   virtual void set_npu_lm_head(layer::NpuLmHead& head) = 0;
   virtual layer::NpuWordEmbedding get_npu_word_embedding() = 0;
   virtual void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) = 0;
-#else
+#endif
   virtual layer::LmHead get_lm_head() {
     LOG(FATAL)
         << "Method 'get_lm_head' is not implemented/supported by this model.";
@@ -130,7 +127,6 @@ class CausalLM : public torch::nn::Module {
     LOG(FATAL) << "Method 'set_word_embedding' is not implemented/supported by "
                   "this model.";
   }
-#endif
 };

 template <typename Model>
@@ -180,7 +176,7 @@ class CausalLMImpl : public CausalLM {
   void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override {
     model_->set_npu_word_embedding(embedding);
   }
-#else
+#endif
   layer::LmHead get_lm_head() override {
     if constexpr (detail::has_get_lm_head<Model>::value) {
       return model_->get_lm_head();
@@ -212,7 +208,6 @@ class CausalLMImpl : public CausalLM {
       CausalLM::set_word_embedding(embedding);
     }
   }
-#endif

   torch::Device device() const override { return options_.device(); }
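With the #if !defined(USE_NPU) guard around namespace detail removed, the has_get_lm_head / has_set_word_embedding traits now compile on every backend, and CausalLMImpl keeps using them with if constexpr to forward a call only when the wrapped model actually provides it, otherwise falling back to the base class, which logs a fatal "not implemented" message. A self-contained sketch of that detection idiom, with toy types standing in for the real model and layer classes:

#include <iostream>
#include <type_traits>
#include <utility>

struct LmHead {};  // stand-in for layer::LmHead

// Detection trait in the style of detail::has_get_lm_head: true iff a model
// handle T supports the expression t->get_lm_head().
template <typename T, typename = void>
struct has_get_lm_head : std::false_type {};

template <typename T>
struct has_get_lm_head<T,
                       std::void_t<decltype(std::declval<T>()->get_lm_head())>>
    : std::true_type {};

struct WithHead {
  LmHead get_lm_head() { return {}; }
};
struct WithoutHead {};

// Mirrors the shape of CausalLMImpl::get_lm_head(): forward when the model
// exposes the accessor, otherwise take the fallback path (the real code falls
// back to the CausalLM base class, which reports "not implemented").
template <typename Model>
LmHead get_lm_head_or_default(Model* model) {
  if constexpr (has_get_lm_head<Model*>::value) {
    return model->get_lm_head();
  } else {
    std::cout << "get_lm_head is not implemented by this model\n";
    return {};
  }
}

int main() {
  WithHead a;
  WithoutHead b;
  get_lm_head_or_default(&a);  // forwards to WithHead::get_lm_head()
  get_lm_head_or_default(&b);  // takes the fallback branch
}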

xllm/core/framework/model/causal_vlm.h

Lines changed: 1 addition & 2 deletions
@@ -79,7 +79,7 @@ class CausalVLMImpl : public CausalVLM {
   void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override {
     model_->set_npu_word_embedding(embedding);
   }
-#else
+#endif
   layer::LmHead get_lm_head() override {
     if constexpr (detail::has_get_lm_head<Model>::value) {
       return model_->get_lm_head();
@@ -111,7 +111,6 @@ class CausalVLMImpl : public CausalVLM {
       CausalLM::set_word_embedding(embedding);
     }
   }
-#endif

   torch::Device device() const override { return options_.device(); }

xllm/core/framework/model/mm_embedding_vlm.h

Lines changed: 1 addition & 2 deletions
@@ -72,12 +72,11 @@ class MMEmbeddingVLMImpl : public MMEmbeddingVLM {
   virtual void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) {
     return;
   }
-#else
+#endif
   virtual void set_lm_head(layer::LmHead& head) { return; }
   virtual layer::LmHead get_lm_head() { return nullptr; }
   virtual layer::WordEmbedding get_word_embedding() { return nullptr; }
   virtual void set_word_embedding(layer::WordEmbedding& embedding) { return; }
-#endif

   void load_model(std::unique_ptr<ModelLoader> loader) override {
     model_->load_model(std::move(loader));

xllm/core/layers/config.h

Lines changed: 10 additions & 87 deletions
@@ -39,103 +39,26 @@ limitations under the License.
 } \
 }

-#if defined(USE_NPU)
-#include "npu/npu_word_embedding_impl.h"
-#else
-#include "common/word_embedding_impl.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_pos_embedding_impl.h"
-#else
-#include "common/rotary_embedding.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_lm_head_impl.h"
-#else
 #include "common/linear.h"
-UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl)
-#endif
+#include "common/qwen2_5_vision_layer.h"
+#include "common/qwen2_decoder_layer.h"
+#include "common/qwen3_moe_decoder_layer.h"
+#include "common/rotary_embedding.h"
+#include "common/word_embedding_impl.h"

-#if defined(USE_NPU)
-#include "npu/npu_deepseek_v2_decoder_layer_impl.h"
-#elif defined(USE_MLU)
+#if defined(USE_MLU)
 #include "mlu/deepseek_v2_decoder_layer_impl.h"
 #else
 REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV2DecoderLayerImpl);
 #endif

-#if defined(USE_NPU)
-#include "npu/npu_deepseek_v32_decoder_layer_impl.h"
-#else
-REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV32DecoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_llama_decoder_layer_impl.h"
-#else
-REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen2_decoder_layer_impl.h"
-#else
-#include "common/qwen2_decoder_layer.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen2_vision_encoder_layer_impl.h"
-#else
-#include "common/qwen2_5_vision_layer.h"
+UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl)
 UNIFY_CLASS_NAME(Qwen2_VisionLayerImpl, Qwen2VisionEncoderLayerImpl)
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen2dot5_vision_encoder_layer_impl.h"
-#else
-#include "common/qwen2_5_vision_layer.h"
 UNIFY_CLASS_NAME(Qwen2_5_VisionLayerImpl, Qwen2dot5VisionEncoderLayerImpl)
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen3_decoder_layer_impl.h"
-#else
-#include "common/qwen2_decoder_layer.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen3_moe_decoder_layer_impl.h"
-#else
-#include "common/qwen3_moe_decoder_layer.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen3_vision_encoder_layer_impl.h"
-#else
-#include "common/qwen2_5_vision_layer.h"
 UNIFY_CLASS_NAME(Qwen3_VisionLayerImpl, Qwen3VisionEncoderLayerImpl)
-#endif

-#if defined(USE_NPU)
-#include "npu/npu_siglip_encoder_layer_impl.h"
-#else
+REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV32DecoderLayerImpl);
+REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl);
 REGISTER_NOT_IMPLEMENTED_CLASS(SiglipEncoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_glm4_decoder_layer_impl.h"
-#else
 REGISTER_NOT_IMPLEMENTED_CLASS(Glm4DecoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_glm4_vision_encoder_layer_impl.h"
-namespace xllm {
-namespace layer {
-using Glm4VisionEncoderLayerImpl = NpuGlm4VisionEncoderLayerImpl;
-}
-}  // namespace xllm
-#else
-REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl);
-#endif
+REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl);
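config.h leans on two helper macros whose definitions sit above this hunk and are not shown in the diff (the "} \" / "}" context at the top is the tail of one of them): UNIFY_CLASS_NAME, which evidently maps an existing implementation class onto the backend-neutral name the models use, and REGISTER_NOT_IMPLEMENTED_CLASS, which registers a placeholder under that name for backends without support. A hedged sketch of plausible definitions; the real macro bodies may differ (for instance, the real stub presumably reports the error via LOG(FATAL) rather than an exception):

#include <stdexcept>

// Hypothetical shapes for the helper macros used in xllm/core/layers/config.h.
// Alias an existing impl class to the unified name the model code expects.
#define UNIFY_CLASS_NAME(Impl, Unified) \
  namespace xllm {                      \
  namespace layer {                     \
  using Unified = Impl;                 \
  }                                     \
  }

// Register a stub under the unified name that fails loudly if a model tries
// to instantiate a layer the current backend does not provide.
#define REGISTER_NOT_IMPLEMENTED_CLASS(Unified)                 \
  namespace xllm {                                              \
  namespace layer {                                             \
  struct Unified {                                              \
    template <typename... Args>                                 \
    explicit Unified(Args&&...) {                               \
      throw std::runtime_error(#Unified " is not implemented"); \
    }                                                           \
  };                                                            \
  }                                                             \
  }

namespace xllm {
namespace layer {
struct ColumnParallelLinearImpl {};  // stand-in for the real linear layer
}  // namespace layer
}  // namespace xllm

UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl)
REGISTER_NOT_IMPLEMENTED_CLASS(SiglipEncoderLayerImpl);

int main() {
  xllm::layer::LmHeadImpl ok;  // resolves to ColumnParallelLinearImpl
  (void)ok;
  try {
    xllm::layer::SiglipEncoderLayerImpl bad;  // stub: always throws
  } catch (const std::exception& e) {
    // e.what() == "SiglipEncoderLayerImpl is not implemented"
  }
}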

xllm/core/runtime/llm_worker_impl.h

Lines changed: 1 addition & 3 deletions
@@ -59,7 +59,7 @@ class LLMWorkerImpl : public WorkerImpl {
     model_->set_npu_word_embedding(embedding);
   };

-#else
+#endif
   layer::LmHead get_lm_head() { return model_->get_lm_head(); };

   void set_lm_head(layer::LmHead& head) { model_->set_lm_head(head); };
@@ -72,8 +72,6 @@ class LLMWorkerImpl : public WorkerImpl {
     model_->set_word_embedding(embedding);
   };

-#endif
-
  private:
   std::unique_ptr<BeamSearcher> beam_searcher_;
 };

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 4 additions & 4 deletions
@@ -195,10 +195,10 @@ bool SpeculativeWorkerImpl::init_model(const std::string& model_weights_path,
     auto word_embedding = impl_->get_npu_word_embedding();
     draft_impl_->set_npu_word_embedding(word_embedding);
   } else {
-    // TODO: Support TORCH backend via torch_npu encapsulation in the future.
-    // Currently, it is explicitly disabled.
-    LOG(FATAL)
-        << "SpeculativeWorkerImpl::init_model not support TORCH backend";
+    auto head = impl_->get_lm_head();
+    draft_impl_->set_lm_head(head);
+    auto word_embedding = impl_->get_word_embedding();
+    draft_impl_->set_word_embedding(word_embedding);
   }
 #else
   auto head = impl_->get_lm_head();
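The replaced branch previously rejected the TORCH backend outright with LOG(FATAL); it now passes the target model's lm_head and word_embedding handles to the draft model, matching what the non-NPU build below the #else already does. Handing over the handles rather than rebuilding the layers works because torch::nn-style layer handles are reference-counted module holders, so copying one shares the underlying weights instead of duplicating them; the xllm layer::LmHead and layer::WordEmbedding wrappers are assumed here to follow the same convention. A small LibTorch sketch of that sharing semantics:

#include <torch/torch.h>
#include <iostream>

int main() {
  // torch::nn::Linear is a module holder: copying it copies a shared_ptr,
  // not the parameters themselves.
  torch::nn::Linear target_lm_head(torch::nn::LinearOptions(8, 16).bias(false));

  // The "draft model" receives the same holder, analogous to
  // draft_impl_->set_lm_head(head) in the diff above.
  torch::nn::Linear draft_lm_head = target_lm_head;

  // Mutating weights through one handle is visible through the other,
  // confirming that no second copy of the lm_head weights exists.
  torch::NoGradGuard no_grad;
  target_lm_head->weight.zero_();
  std::cout << draft_lm_head->weight.sum().item<float>() << '\n';          // 0
  std::cout << (draft_lm_head.ptr() == target_lm_head.ptr()) << '\n';      // 1
}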
