feat: support FLUX.1-Redux-dev model on npu device. #592
base: main
Conversation
/gemini review
Code Review
This pull request adds support for the FLUX.1-Redux-dev model on NPU devices. The changes are extensive, including a new pipeline implementation, a Siglip vision model, and a corresponding image processor. The overall structure looks good, but I've found a critical issue that could lead to a crash, along with several high and medium severity issues related to incorrect weight loading logic, dead code, and maintainability concerns such as magic numbers. Please address these points to improve the robustness and clarity of the new implementation.
std::vector<torch::Tensor> output = forward_(image.value(),
                                             prompt_embeds,
                                             pooled_prompt_embeds,
                                             generation_params.height,
                                             generation_params.width,
                                             prompt_embeds_scale,
                                             pooled_prompt_embeds_scale);
The code calls image.value() without first checking if image contains a value using image.has_value(). If input.images is not defined in DiTForwardInput, image will be std::nullopt, and calling .value() will result in a std::bad_optional_access exception, crashing the application. You should add a check before accessing the value.
if (!image.has_value()) {
LOG(ERROR) << "Image input is required for FluxPriorReduxPipeline.";
return {};
}
std::vector<torch::Tensor> output = forward_(image.value(),
prompt_embeds,
pooled_prompt_embeds,
generation_params.height,
generation_params.width,
prompt_embeds_scale,
                                             pooled_prompt_embeds_scale);

if (image_preprocess_reader.contains("image_mean")) {
  args_.mm_image_normalize_mean() =
      image_prerocess_data["image_mean"].get<std::vector<double>>();
}

if (image_preprocess_reader.contains("image_std")) {
  args_.mm_image_normalize_std() =
      image_prerocess_data["image_std"].get<std::vector<double>>();
}

if (image_preprocess_reader.contains("norm_mean")) {
  args_.mm_image_normalize_mean() =
      image_prerocess_data["norm_mean"].get<std::vector<double>>();
}

if (image_preprocess_reader.contains("norm_std")) {
  args_.mm_image_normalize_std() =
      image_prerocess_data["norm_std"].get<std::vector<double>>();
}
The code checks for image_mean and then for norm_mean, with both writing to args_.mm_image_normalize_mean(). The same pattern exists for image_std and norm_std. If a config file contains both keys (e.g., image_mean and norm_mean), the value from the latter key will silently overwrite the former. This could lead to unexpected behavior. It would be safer and clearer to use an else if structure to prioritize one key over the other.
Suggested change:

if (image_preprocess_reader.contains("image_mean")) {
  args_.mm_image_normalize_mean() =
      image_prerocess_data["image_mean"].get<std::vector<double>>();
} else if (image_preprocess_reader.contains("norm_mean")) {
  args_.mm_image_normalize_mean() =
      image_prerocess_data["norm_mean"].get<std::vector<double>>();
}
if (image_preprocess_reader.contains("image_std")) {
  args_.mm_image_normalize_std() =
      image_prerocess_data["image_std"].get<std::vector<double>>();
} else if (image_preprocess_reader.contains("norm_std")) {
  args_.mm_image_normalize_std() =
      image_prerocess_data["norm_std"].get<std::vector<double>>();
}
void load_model(std::unique_ptr<DiTFolderLoader> loader) {
  for (const auto& state_dict : loader->get_state_dicts()) {
    redux_up_->load_state_dict(state_dict->get_dict_with_prefix("redux_up."));
    redux_up_weight_loaded_ = true;
    redux_up_bias_loaded_ = true;
    redux_down_->load_state_dict(
        state_dict->get_dict_with_prefix("redux_down."));
    redux_down_weight_loaded_ = true;
    redux_down_bias_loaded_ = true;
  }
The load_model function iterates through all state_dicts and loads weights for redux_up_ and redux_down_ in each iteration. If there are multiple state dict files, this will cause weights to be reloaded multiple times, which is inefficient. If weights for a single module are split across files, this logic is incorrect as it will only retain the weights from the last file. A more robust approach would be to iterate through the keys in each state dictionary and load them into the corresponding module parameters. A similar issue exists in SiglipVisionModelImpl::load_model.
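A torch-free sketch of the key-routing idea suggested above, using plain maps as stand-ins for state dicts (all names here are hypothetical, not the PR's actual API). Each key is dispatched by prefix and the set of loaded tensors accumulates across files, so weights split over several state-dict files are all retained instead of being overwritten by the last file:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical sketch: walk every key in every state-dict file, route keys to
// the matching module by prefix, and record which tensors have been loaded.
std::unordered_set<std::string> load_all(
    const std::vector<std::unordered_map<std::string, int>>& state_dicts) {
  std::unordered_set<std::string> loaded;
  for (const auto& sd : state_dicts) {
    for (const auto& kv : sd) {
      const std::string& key = kv.first;
      // Dispatch on prefix instead of reloading whole modules per file.
      if (key.rfind("redux_up.", 0) == 0 || key.rfind("redux_down.", 0) == 0) {
        // In the real code, copy kv.second into the matching parameter here.
        loaded.insert(key);
      }
    }
  }
  return loaded;
}
```

With this shape, verify_loaded_weights can check the accumulated set once after all files are processed, rather than relying on flags set unconditionally inside the loop.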
}

void load_state_dict(const StateDict& state_dict) {
  const auto pos = state_dict.get_tensor("position_embedding.weight");
The code attempts to load a tensor named "position_embedding.weight", but the parameter was registered as "position_embedding" on line 51. This mismatch will cause the weight loading to fail silently, as get_tensor will return an undefined tensor. The correct name should be used.
Suggested change:

const auto pos = state_dict.get_tensor("position_embedding");
void load_model(std::unique_ptr<DiTFolderLoader> loader) {
  for (const auto& state_dict : loader->get_state_dicts()) {
    transformer_->load_state_dict(
        state_dict->get_dict_with_prefix("vision_model."));
  }

  transformer_->verify_loaded_weights("vision_model.");
}
The load_model function iterates through all state_dicts and loads weights for the transformer in each iteration. If there are multiple state dict files, this will cause weights to be reloaded multiple times, which is inefficient. If weights for a single module are split across files, this logic is incorrect as it will only retain the weights from the last file. A more robust approach would be to iterate through the keys in each state dictionary and load them into the corresponding module parameters.
torch::Tensor forward(const torch::Tensor& pixel_values) {
  int64_t batch_size = pixel_values.size(0);
  int64_t seq_length = pixel_values.size(1);
  is_patch_embedding_loaded = true;
}

const auto bias = state_dict.get_tensor("patch_embedding.bias");
if (bias.defined()) {
  CHECK_EQ(patch_embedding_->bias.sizes(), bias.sizes())
      << "patch_embedding bias size mismatch for " << name();
  patch_embedding_->bias.data().copy_(bias);
  is_patch_embedding_loaded = true;
The logic for is_patch_embedding_loaded is flawed. It is set to true if the bias is loaded, even if the weight is not. This will cause verify_loaded_weights to pass incorrectly. Consider using separate flags for weight and bias (is_patch_embedding_weight_loaded and is_patch_embedding_bias_loaded) and check both in verify_loaded_weights.
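A minimal sketch of the suggested bookkeeping (struct and member names are hypothetical, not from the PR): track the weight and bias independently, and only report the embedding as loaded when both flags are set, so a bias-only load can no longer make verification pass.

```cpp
#include <cassert>

// Hypothetical sketch: separate flags per tensor, combined at verify time.
struct PatchEmbeddingStatus {
  bool weight_loaded = false;
  bool bias_loaded = false;
  // verify_loaded_weights would check this instead of a single shared flag.
  bool fully_loaded() const { return weight_loaded && bias_loaded; }
};
```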
torch::Tensor forward(const torch::Tensor& embeddings) {
  bool output_hidden_states = false;
  bool output_attentions = false;
  std::vector<torch::Tensor> encoder_states;

  auto hidden_states = embeddings;
  for (size_t i = 0; i < layers_.size(); ++i) {
    encoder_states.emplace_back(hidden_states);
    auto& layer = layers_[i];
    hidden_states = layer->forward(hidden_states);
  }
  std::vector<torch::Tensor> outputs = {hidden_states};
  return outputs[0];
}
In SiglipEncoderImpl::forward, the variables output_hidden_states and output_attentions are initialized to false and never used. The encoder_states vector is populated in the loop but its contents are never used or returned. This constitutes dead code and should be removed for clarity. The final return can also be simplified.
torch::Tensor forward(const torch::Tensor& embeddings) {
  auto hidden_states = embeddings;
  for (size_t i = 0; i < layers_.size(); ++i) {
    auto& layer = layers_[i];
    hidden_states = layer->forward(hidden_states);
  }
  return hidden_states;
}

}

private:
  torch::nn::ModuleList blocks_ = nullptr;
std::vector<int64_t> SiglipImageProcessor::get_resize_output_image_size(
    const torch::Tensor& image,
    int shortest_edge) {
  int height = image.size(1);
  int width = image.size(2);

  int short_size = std::min(height, width);
  int long_size = std::max(height, width);

  int new_short = shortest_edge;
  int new_long = static_cast<int>(shortest_edge *
                                  static_cast<float>(long_size) / short_size);

  return height < width ? std::vector<int64_t>({new_short, new_long})
                        : std::vector<int64_t>({new_long, new_short});
}
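The resize rule above maps the shorter edge to shortest_edge and scales the longer edge to preserve aspect ratio. A torch-free restatement (the free function name resize_output is hypothetical; the real method takes a tensor) makes the arithmetic easy to check, e.g. a 480x640 image with shortest_edge 384 resizes to 384x512:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Torch-free sketch of get_resize_output_image_size: the shorter edge becomes
// shortest_edge, the longer edge scales proportionally, and the result is
// returned in {height, width} order.
std::vector<int64_t> resize_output(int height, int width, int shortest_edge) {
  int short_size = std::min(height, width);
  int long_size = std::max(height, width);
  int new_short = shortest_edge;
  int new_long = static_cast<int>(
      shortest_edge * static_cast<float>(long_size) / short_size);
  return height < width ? std::vector<int64_t>{new_short, new_long}
                        : std::vector<int64_t>{new_long, new_short};
}
```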
This PR enables black-forest-labs/FLUX.1-Redux-dev to run on NPU devices.
The outputs prompt_embeds and pooled_prompt_embeds are not included yet and will be added in a later update.