Skip to content

Commit 9fdf241

Browse files
committed
feat: support FLUX.1-Redux-dev model on npu device.
1 parent 49be0a6 commit 9fdf241

File tree

12 files changed

+1006
-39
lines changed

12 files changed

+1006
-39
lines changed

xllm/core/framework/dit_model_loader.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ bool DiTFolderLoader::load_args(const std::string& model_weights_path) {
8181
return false;
8282
}
8383

84+
if (!load_image_preprocessor_args(model_weights_path)) {
85+
LOG(ERROR) << "Failed to load image preprocess args from "
86+
<< model_weights_path;
87+
return false;
88+
}
89+
8490
return true;
8591
}
8692

@@ -219,6 +225,101 @@ bool DiTFolderLoader::load_tokenizer_args(
219225
return true;
220226
}
221227

228+
bool DiTFolderLoader::load_image_preprocessor_args(
229+
const std::string& model_weights_path) {
230+
// image preprocessor args
231+
JsonReader image_preprocess_reader;
232+
const std::string image_preprocess_file_path =
233+
model_weights_path + "/preprocessor_config.json";
234+
if (image_preprocess_reader.parse(image_preprocess_file_path)) {
235+
LOG(INFO) << "Success to parse image preprocess args file: "
236+
<< image_preprocess_file_path;
237+
args_.mm_image_do_center_crop() =
238+
image_preprocess_reader.value_or<bool>("do_center_crop", false);
239+
args_.mm_image_crop_height_size() =
240+
image_preprocess_reader.value_or<int>("crop_size.height", 335);
241+
args_.mm_image_crop_width_size() =
242+
image_preprocess_reader.value_or<int>("crop_size.width", 335);
243+
244+
args_.mm_image_size_height() =
245+
image_preprocess_reader.value_or<int>("size.height", 384);
246+
247+
args_.mm_image_size_width() =
248+
image_preprocess_reader.value_or<int>("size.width", 384);
249+
250+
args_.mm_image_do_resize() =
251+
image_preprocess_reader.value_or<bool>("do_resize", false);
252+
args_.mm_image_resize_shortest_edge() =
253+
image_preprocess_reader.value_or<int>("size.shortest_edge", 335);
254+
args_.mm_image_resample() =
255+
image_preprocess_reader.value_or<int>("resample", 335);
256+
257+
args_.mm_image_do_rescale() =
258+
image_preprocess_reader.value_or<bool>("do_rescale", false);
259+
args_.mm_image_rescale_factor() =
260+
image_preprocess_reader.value_or<double>("rescale_factor", 0);
261+
262+
args_.mm_image_do_normalize() =
263+
image_preprocess_reader.value_or<bool>("do_normalize", false);
264+
265+
const auto& image_prerocess_data = image_preprocess_reader.data();
266+
if (image_preprocess_reader.contains("image_mean")) {
267+
args_.mm_image_normalize_mean() =
268+
image_prerocess_data["image_mean"].get<std::vector<double>>();
269+
}
270+
271+
if (image_preprocess_reader.contains("image_std")) {
272+
args_.mm_image_normalize_std() =
273+
image_prerocess_data["image_std"].get<std::vector<double>>();
274+
}
275+
276+
if (image_preprocess_reader.contains("norm_mean")) {
277+
args_.mm_image_normalize_mean() =
278+
image_prerocess_data["norm_mean"].get<std::vector<double>>();
279+
}
280+
281+
if (image_preprocess_reader.contains("norm_std")) {
282+
args_.mm_image_normalize_std() =
283+
image_prerocess_data["norm_std"].get<std::vector<double>>();
284+
}
285+
286+
args_.mm_image_shortest_edge() =
287+
image_preprocess_reader.value_or<int>("size.shortest_edge", 0);
288+
289+
args_.mm_image_longest_edge() =
290+
image_preprocess_reader.value_or<int>("size.longest_edge", 0);
291+
292+
args_.mm_image_min_pixels() =
293+
image_preprocess_reader.value_or<int>("min_pixels", 0);
294+
295+
args_.mm_image_max_pixels() =
296+
image_preprocess_reader.value_or<int>("max_pixels", 0);
297+
298+
args_.mm_image_patch_size() =
299+
image_preprocess_reader.value_or<int>("patch_size", 0);
300+
301+
args_.mm_image_temporal_patch_size() =
302+
image_preprocess_reader.value_or<int>("temporal_patch_size", 0);
303+
304+
args_.mm_image_merge_size() =
305+
image_preprocess_reader.value_or<int>("merge_size", 0);
306+
307+
args_.mm_image_feature_size() =
308+
image_preprocess_reader.value_or<int>("image_feature_size", 0);
309+
310+
args_.mm_scale_resolution() =
311+
image_preprocess_reader.value_or<int>("scale_resolution", 0);
312+
313+
args_.mm_slice_mode() =
314+
image_preprocess_reader.value_or<bool>("slice_mode", false);
315+
316+
args_.mm_use_image_id() =
317+
image_preprocess_reader.value_or<bool>("use_image_id", false);
318+
}
319+
320+
return true;
321+
}
322+
222323
DiTModelLoader::DiTModelLoader(const std::string& model_root_path)
223324
: model_root_path_(model_root_path) {
224325
if (!std::filesystem::exists(model_root_path_)) {

xllm/core/framework/dit_model_loader.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class DiTFolderLoader {
4343
bool load_args(const std::string& model_weights_path);
4444
bool load_model_args(const std::string& model_weights_path);
4545
bool load_tokenizer_args(const std::string& model_weights_path);
46+
bool load_image_preprocessor_args(const std::string& model_weights_path);
4647

4748
// model args
4849
ModelArgs args_;

xllm/core/framework/model/model_args.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ struct ModelArgs {
287287
// VLM image preprocessor resize
288288
PROPERTY(bool, mm_image_do_resize) = false;
289289
PROPERTY(int, mm_image_resize_shortest_edge) = 336;
290+
PROPERTY(int64_t, mm_image_size_height) = 384;
291+
PROPERTY(int64_t, mm_image_size_width) = 384;
290292

291293
PROPERTY(int, mm_image_resample) = 0;
292294

xllm/core/framework/request/dit_request_state.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ struct DiTGenerationParams {
4141
num_images_per_prompt == other.num_images_per_prompt &&
4242
seed == other.seed &&
4343
max_sequence_length == other.max_sequence_length &&
44-
strength == other.strength;
44+
strength == other.strength &&
45+
prompt_embeds_scale == other.prompt_embeds_scale &&
46+
pooled_prompt_embeds_scale == other.pooled_prompt_embeds_scale;
4547
}
4648

4749
bool operator!=(const DiTGenerationParams& other) const {
@@ -65,6 +67,10 @@ struct DiTGenerationParams {
6567
int32_t max_sequence_length = 512;
6668

6769
float strength = 1.0;
70+
71+
float prompt_embeds_scale = 1.0;
72+
73+
float pooled_prompt_embeds_scale = 1.0;
6874
};
6975

7076
struct DiTInputParams {

xllm/models/dit/clip_text_model.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,13 @@ limitations under the License.
1616

1717
#pragma once
1818

19-
#include <atb/atb_infer.h>
20-
#include <c10/core/ScalarType.h>
2119
#include <torch/torch.h>
2220

23-
#include <regex>
24-
#include <unordered_map>
25-
26-
#include "core/framework/dit_model_loader.h"
27-
#include "core/framework/kv_cache/kv_cache.h"
2821
#include "core/framework/model/model_input_params.h"
2922
#include "core/framework/model_context.h"
30-
#include "core/layers/npu/npu_siglip_encoder_layer_impl.h"
3123
#include "dit_linear.h"
3224
#include "models/model_registry.h"
33-
#include "processors/clip_image_processor.h"
3425
#include "processors/input_processor.h"
35-
#include "processors/pywarpper_image_processor.h"
36-
#include "xllm_kernels/core/include/atb_speed/log.h"
3726

3827
namespace xllm {
3928
// clip_text_model compatible with huggingface weights
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
#include "pipeline_flux_base.h"
18+
#include "processors/siglip_image_processor.h"
19+
#include "siglip_vision_model.h"
20+
// pipeline_flux_prior_redux compatible with huggingface weights
21+
// ref to:
22+
// https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
23+
24+
namespace xllm {
25+
26+
// MLP projection head that maps SigLIP image features into the Flux text
// embedding space: redux_up (redux_dim -> 3 * txt_in_features) -> SiLU ->
// redux_down (-> txt_in_features).
class ReduxImageEncoderImpl : public torch::nn::Module {
 public:
  explicit ReduxImageEncoderImpl(const ModelContext& context) {
    auto model_args = context.get_model_args();
    auto options = context.get_tensor_options();
    act_ = register_module("act", torch::nn::Functional(torch::silu));

    // Both linear layers carry a bias (last ctor arg = true).
    redux_up_ = register_module("redux_up",
                                DiTLinear(model_args.mm_hidden_size(),
                                          model_args.mm_intermediate_size() * 3,
                                          true));
    redux_down_ =
        register_module("redux_down",
                        DiTLinear(model_args.mm_intermediate_size() * 3,
                                  model_args.mm_intermediate_size(),
                                  true));
    redux_up_->to(options);
    redux_down_->to(options);
  }

  // hidden_states: image features from the vision tower; returns the
  // projected embeddings (redux_down(silu(redux_up(x)))).
  torch::Tensor forward(const torch::Tensor& hidden_states) {
    return redux_down_(act_(redux_up_(hidden_states)));
  }

  // Loads redux_up./redux_down. weights from every state dict the loader
  // provides.
  // NOTE(review): the *_loaded_ flags are set unconditionally per state
  // dict, so verify_loaded_weights() only confirms load_model() ran, not
  // that the tensors were actually present — confirm load_state_dict
  // itself fails on missing keys.
  void load_model(std::unique_ptr<DiTFolderLoader> loader) {
    for (const auto& state_dict : loader->get_state_dicts()) {
      redux_up_->load_state_dict(state_dict->get_dict_with_prefix("redux_up."));
      redux_up_weight_loaded_ = true;
      redux_up_bias_loaded_ = true;
      redux_down_->load_state_dict(
          state_dict->get_dict_with_prefix("redux_down."));
      redux_down_weight_loaded_ = true;
      redux_down_bias_loaded_ = true;
    }
  }

  // Aborts (CHECK) if any expected weight was not loaded; prefix is used
  // only to build the error message.
  void verify_loaded_weights(const std::string& prefix) const {
    CHECK(redux_up_weight_loaded_)
        << "weight is not loaded for " << prefix + "redux_up.weight";
    CHECK(redux_up_bias_loaded_)
        << "weight is not loaded for " << prefix + "redux_up.bias";
    CHECK(redux_down_weight_loaded_)
        << "weight is not loaded for " << prefix + "redux_down.weight";
    CHECK(redux_down_bias_loaded_)
        << "weight is not loaded for " << prefix + "redux_down.bias";
  }

 private:
  DiTLinear redux_up_{nullptr};
  DiTLinear redux_down_{nullptr};

  torch::nn::Functional act_ = nullptr;
  // Weight-loading bookkeeping consumed by verify_loaded_weights().
  bool redux_up_weight_loaded_ = false;
  bool redux_up_bias_loaded_ = false;
  bool redux_down_weight_loaded_ = false;
  bool redux_down_bias_loaded_ = false;
};
TORCH_MODULE(ReduxImageEncoder);
84+
85+
// Model-arg defaults for the Redux projection head; values match the
// FLUX.1-Redux-dev config (redux_dim=1152, txt_in_features=4096).
REGISTER_MODEL_ARGS(ReduxImageEncoder, [&] {
  LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16");
  LOAD_ARG_OR(mm_hidden_size, "redux_dim", 1152);
  LOAD_ARG_OR(mm_intermediate_size, "txt_in_features", 4096);
});
90+
91+
class FluxPriorReduxPipelineImpl : public FluxPipelineBaseImpl {
92+
public:
93+
FluxPriorReduxPipelineImpl(const DiTModelContext& context) {
94+
auto model_args = context.get_model_args("feature_extractor");
95+
options_ = context.get_tensor_options();
96+
image_encoder_ =
97+
SiglipVisionModel(context.get_model_context("image_encoder"));
98+
image_embedder_ =
99+
ReduxImageEncoder(context.get_model_context("image_embedder"));
100+
feature_extractor_ = std::make_unique<SiglipImageProcessor>(model_args);
101+
}
102+
103+
void load_model(std::unique_ptr<DiTModelLoader> loader) {
104+
std::string model_path = loader->model_root_path();
105+
auto image_encoder_loader = loader->take_component_loader("image_encoder");
106+
auto image_embedder_loader =
107+
loader->take_component_loader("image_embedder");
108+
image_encoder_->load_model(std::move(image_encoder_loader));
109+
image_encoder_->to(options_.device());
110+
image_embedder_->load_model(std::move(image_embedder_loader));
111+
image_embedder_->to(options_.device());
112+
}
113+
114+
torch::Tensor encode_image(const torch::Tensor& image,
115+
int64_t num_images_per_prompt) {
116+
auto imgs = feature_extractor_->preprocess(image).to(options_);
117+
auto image_enc_hidden_states = image_encoder_->forward(imgs);
118+
image_enc_hidden_states =
119+
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, 0);
120+
return image_enc_hidden_states;
121+
}
122+
123+
DiTForwardOutput forward(const DiTForwardInput& input) {
124+
const auto& generation_params = input.generation_params;
125+
auto image = input.images.defined() ? std::make_optional(input.images)
126+
: std::nullopt;
127+
auto prompt_embeds = input.prompt_embeds.defined()
128+
? std::make_optional(input.prompt_embeds)
129+
: std::nullopt;
130+
auto pooled_prompt_embeds =
131+
input.pooled_prompt_embeds.defined()
132+
? std::make_optional(input.pooled_prompt_embeds)
133+
: std::nullopt;
134+
auto prompt_embeds_scale = generation_params.prompt_embeds_scale;
135+
auto pooled_prompt_embeds_scale =
136+
generation_params.pooled_prompt_embeds_scale;
137+
std::vector<torch::Tensor> output = forward_(image.value(),
138+
prompt_embeds,
139+
pooled_prompt_embeds,
140+
generation_params.height,
141+
generation_params.width,
142+
prompt_embeds_scale,
143+
pooled_prompt_embeds_scale);
144+
DiTForwardOutput out;
145+
out.tensors = output;
146+
return out;
147+
}
148+
149+
std::vector<torch::Tensor> forward_(
150+
torch::Tensor image,
151+
std::optional<torch::Tensor> prompt_embeds_opt,
152+
std::optional<torch::Tensor> pooled_prompt_embeds_opt,
153+
int64_t height = 384,
154+
int64_t width = 384,
155+
float prompt_embeds_scale = 1.0f,
156+
float pooled_prompt_embeds_scale = 1.0f) {
157+
torch::NoGradGuard no_grad;
158+
int64_t batch_size = image.dim() == 4 ? image.size(0) : 1;
159+
torch::Tensor image_latents =
160+
encode_image(image, /*num_images_per_prompt=*/1);
161+
torch::Tensor image_embeds =
162+
image_embedder_->forward(image_latents).to(options_);
163+
164+
// prompt_embeds: [batch_size, seq_len, hidden_dim]
165+
torch::Tensor prompt_embeds = prompt_embeds_opt.value_or(
166+
torch::zeros({batch_size, 512, 4096}, options_));
167+
// pooled_prompt_embeds: [batch_size, pooled_hidden_dim]
168+
torch::Tensor pooled_prompt_embeds = pooled_prompt_embeds_opt.value_or(
169+
torch::zeros({batch_size, 768}, options_));
170+
171+
prompt_embeds = torch::cat({prompt_embeds, image_embeds}, /*dim=*/1);
172+
prompt_embeds *= torch::full({batch_size}, prompt_embeds_scale, options_)
173+
.view({-1, 1, 1});
174+
pooled_prompt_embeds *=
175+
torch::full({batch_size}, pooled_prompt_embeds_scale, options_)
176+
.view({-1, 1});
177+
178+
prompt_embeds = torch::sum(prompt_embeds, /*dim=*/0, /*keepdim=*/true);
179+
pooled_prompt_embeds =
180+
torch::sum(pooled_prompt_embeds, /*dim=*/0, /*keepdim=*/true);
181+
182+
return {prompt_embeds, pooled_prompt_embeds};
183+
}
184+
185+
private:
186+
SiglipVisionModel image_encoder_{nullptr};
187+
std::unique_ptr<SiglipImageProcessor> feature_extractor_;
188+
ReduxImageEncoder image_embedder_{nullptr};
189+
};
190+
TORCH_MODULE(FluxPriorReduxPipeline);
191+
192+
REGISTER_DIT_MODEL(fluxredux, FluxPriorReduxPipeline);
193+
} // namespace xllm

0 commit comments

Comments
 (0)