
Conversation

@z-jun03 (Collaborator) commented Dec 24, 2025

This PR enables black-forest-labs/FLUX.1-Redux-dev to run on NPU devices.
The outputs prompt_embeds and pooled_prompt_embeds are not included yet and will be added in a later update.

@XuZhang99 (Collaborator) commented:

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for the FLUX.1-Redux-dev model on NPU devices. The changes are extensive, including a new pipeline implementation, a Siglip vision model, and a corresponding image processor. The overall structure looks good, but I've found a critical issue that could lead to a crash, along with several high- and medium-severity issues related to incorrect weight-loading logic, dead code, and maintainability concerns such as magic numbers. Please address these points to improve the robustness and clarity of the new implementation.

Comment on lines +137 to +143
    std::vector<torch::Tensor> output = forward_(image.value(),
                                                 prompt_embeds,
                                                 pooled_prompt_embeds,
                                                 generation_params.height,
                                                 generation_params.width,
                                                 prompt_embeds_scale,
                                                 pooled_prompt_embeds_scale);

critical

The code calls image.value() without first checking if image contains a value using image.has_value(). If input.images is not defined in DiTForwardInput, image will be std::nullopt, and calling .value() will result in a std::bad_optional_access exception, crashing the application. You should add a check before accessing the value.

    if (!image.has_value()) {
      LOG(ERROR) << "Image input is required for FluxPriorReduxPipeline.";
      return {};
    }
    std::vector<torch::Tensor> output = forward_(image.value(),
                                                 prompt_embeds,
                                                 pooled_prompt_embeds,
                                                 generation_params.height,
                                                 generation_params.width,
                                                 prompt_embeds_scale,
                                                 pooled_prompt_embeds_scale);

Comment on lines +266 to +284
  if (image_preprocess_reader.contains("image_mean")) {
    args_.mm_image_normalize_mean() =
        image_prerocess_data["image_mean"].get<std::vector<double>>();
  }

  if (image_preprocess_reader.contains("image_std")) {
    args_.mm_image_normalize_std() =
        image_prerocess_data["image_std"].get<std::vector<double>>();
  }

  if (image_preprocess_reader.contains("norm_mean")) {
    args_.mm_image_normalize_mean() =
        image_prerocess_data["norm_mean"].get<std::vector<double>>();
  }

  if (image_preprocess_reader.contains("norm_std")) {
    args_.mm_image_normalize_std() =
        image_prerocess_data["norm_std"].get<std::vector<double>>();
  }

high

The code checks for image_mean and then for norm_mean, with both writing to args_.mm_image_normalize_mean(). The same pattern exists for image_std and norm_std. If a config file contains both keys (e.g., image_mean and norm_mean), the value from the latter key will silently overwrite the former. This could lead to unexpected behavior. It would be safer and clearer to use an else if structure to prioritize one key over the other.

Suggested change
  if (image_preprocess_reader.contains("image_mean")) {
    args_.mm_image_normalize_mean() =
        image_prerocess_data["image_mean"].get<std::vector<double>>();
  }
  if (image_preprocess_reader.contains("image_std")) {
    args_.mm_image_normalize_std() =
        image_prerocess_data["image_std"].get<std::vector<double>>();
  }
  if (image_preprocess_reader.contains("norm_mean")) {
    args_.mm_image_normalize_mean() =
        image_prerocess_data["norm_mean"].get<std::vector<double>>();
  }
  if (image_preprocess_reader.contains("norm_std")) {
    args_.mm_image_normalize_std() =
        image_prerocess_data["norm_std"].get<std::vector<double>>();
  }
  if (image_preprocess_reader.contains("image_mean")) {
    args_.mm_image_normalize_mean() =
        image_prerocess_data["image_mean"].get<std::vector<double>>();
  } else if (image_preprocess_reader.contains("norm_mean")) {
    args_.mm_image_normalize_mean() =
        image_prerocess_data["norm_mean"].get<std::vector<double>>();
  }
  if (image_preprocess_reader.contains("image_std")) {
    args_.mm_image_normalize_std() =
        image_prerocess_data["image_std"].get<std::vector<double>>();
  } else if (image_preprocess_reader.contains("norm_std")) {
    args_.mm_image_normalize_std() =
        image_prerocess_data["norm_std"].get<std::vector<double>>();
  }

Comment on lines +50 to +59
  void load_model(std::unique_ptr<DiTFolderLoader> loader) {
    for (const auto& state_dict : loader->get_state_dicts()) {
      redux_up_->load_state_dict(state_dict->get_dict_with_prefix("redux_up."));
      redux_up_weight_loaded_ = true;
      redux_up_bias_loaded_ = true;
      redux_down_->load_state_dict(
          state_dict->get_dict_with_prefix("redux_down."));
      redux_down_weight_loaded_ = true;
      redux_down_bias_loaded_ = true;
    }

high

The load_model function iterates through all state_dicts and loads weights for redux_up_ and redux_down_ in each iteration. If there are multiple state dict files, this will cause weights to be reloaded multiple times, which is inefficient. If weights for a single module are split across files, this logic is incorrect as it will only retain the weights from the last file. A more robust approach would be to iterate through the keys in each state dictionary and load them into the corresponding module parameters. A similar issue exists in SiglipVisionModelImpl::load_model.
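
A minimal sketch of that per-tensor approach for the redux projections, assuming the state dict handles returned by get_state_dicts() expose the same get_tensor() accessor used elsewhere in this diff and that redux_up_/redux_down_ are torch::nn::Linear modules:

  // Sketch only: copy whichever tensors each file actually contains, so weights
  // split across multiple checkpoint files accumulate instead of only the last
  // file surviving, and flip the loaded flags per tensor rather than per file.
  void load_model(std::unique_ptr<DiTFolderLoader> loader) {
    for (const auto& state_dict : loader->get_state_dicts()) {
      const auto up_weight = state_dict->get_tensor("redux_up.weight");
      if (up_weight.defined()) {
        redux_up_->weight.data().copy_(up_weight);
        redux_up_weight_loaded_ = true;
      }
      const auto up_bias = state_dict->get_tensor("redux_up.bias");
      if (up_bias.defined()) {
        redux_up_->bias.data().copy_(up_bias);
        redux_up_bias_loaded_ = true;
      }
      const auto down_weight = state_dict->get_tensor("redux_down.weight");
      if (down_weight.defined()) {
        redux_down_->weight.data().copy_(down_weight);
        redux_down_weight_loaded_ = true;
      }
      const auto down_bias = state_dict->get_tensor("redux_down.bias");
      if (down_bias.defined()) {
        redux_down_->bias.data().copy_(down_bias);
        redux_down_bias_loaded_ = true;
      }
    }
  }

The same per-tensor pattern would also cover the similar case in SiglipVisionModelImpl::load_model mentioned above.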

  }

  void load_state_dict(const StateDict& state_dict) {
    const auto pos = state_dict.get_tensor("position_embedding.weight");

high

The code attempts to load a tensor named "position_embedding.weight", but the parameter was registered as "position_embedding" on line 51. This mismatch will cause the weight loading to fail silently, as get_tensor will return an undefined tensor. The correct name should be used.

Suggested change
const auto pos = state_dict.get_tensor("position_embedding.weight");
const auto pos = state_dict.get_tensor("position_embedding");

Comment on lines +498 to +505
  void load_model(std::unique_ptr<DiTFolderLoader> loader) {
    for (const auto& state_dict : loader->get_state_dicts()) {
      transformer_->load_state_dict(
          state_dict->get_dict_with_prefix("vision_model."));
    }

    transformer_->verify_loaded_weights("vision_model.");
  }

high

The load_model function iterates through all state_dicts and loads weights for the transformer in each iteration. If there are multiple state dict files, this will cause weights to be reloaded multiple times, which is inefficient. If weights for a single module are split across files, this logic is incorrect as it will only retain the weights from the last file. A more robust approach would be to iterate through the keys in each state dictionary and load them into the corresponding module parameters.


  torch::Tensor forward(const torch::Tensor& pixel_values) {
    int64_t batch_size = pixel_values.size(0);
    int64_t seq_length = pixel_values.size(1);

medium

The variable seq_length is calculated but never used within the SiglipVisionEmbeddingsImpl::forward function. It should be removed to improve code clarity.

Comment on lines +82 to +90
      is_patch_embedding_loaded = true;
    }

    const auto bias = state_dict.get_tensor("patch_embedding.bias");
    if (bias.defined()) {
      CHECK_EQ(patch_embedding_->bias.sizes(), bias.sizes())
          << "patch_embedding bias size mismatch for " << name();
      patch_embedding_->bias.data().copy_(bias);
      is_patch_embedding_loaded = true;

medium

The logic for is_patch_embedding_loaded is flawed. It is set to true if the bias is loaded, even if the weight is not. This will cause verify_loaded_weights to pass incorrectly. Consider using separate flags for weight and bias (is_patch_embedding_weight_loaded and is_patch_embedding_bias_loaded) and check both in verify_loaded_weights.
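
A sketch of what that split could look like, reusing the CHECK_EQ/name() helpers already in this file; the flag names are the ones proposed above, and the verify_loaded_weights() signature is assumed from its use elsewhere in this diff:

  // Sketch only: track weight and bias separately so verification catches a
  // checkpoint that provides only one of the two tensors.
  void load_state_dict(const StateDict& state_dict) {
    const auto weight = state_dict.get_tensor("patch_embedding.weight");
    if (weight.defined()) {
      CHECK_EQ(patch_embedding_->weight.sizes(), weight.sizes())
          << "patch_embedding weight size mismatch for " << name();
      patch_embedding_->weight.data().copy_(weight);
      is_patch_embedding_weight_loaded = true;
    }

    const auto bias = state_dict.get_tensor("patch_embedding.bias");
    if (bias.defined()) {
      CHECK_EQ(patch_embedding_->bias.sizes(), bias.sizes())
          << "patch_embedding bias size mismatch for " << name();
      patch_embedding_->bias.data().copy_(bias);
      is_patch_embedding_bias_loaded = true;
    }
    // ... position_embedding handling as before ...
  }

  void verify_loaded_weights(const std::string& prefix) const {
    CHECK(is_patch_embedding_weight_loaded)
        << "weight is not loaded for " << prefix << "patch_embedding.weight";
    CHECK(is_patch_embedding_bias_loaded)
        << "bias is not loaded for " << prefix << "patch_embedding.bias";
  }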

Comment on lines +386 to +399
  torch::Tensor forward(const torch::Tensor& embeddings) {
    bool output_hidden_states = false;
    bool output_attentions = false;
    std::vector<torch::Tensor> encoder_states;

    auto hidden_states = embeddings;
    for (size_t i = 0; i < layers_.size(); ++i) {
      encoder_states.emplace_back(hidden_states);
      auto& layer = layers_[i];
      hidden_states = layer->forward(hidden_states);
    }
    std::vector<torch::Tensor> outputs = {hidden_states};
    return outputs[0];
  }

medium

In SiglipEncoderImpl::forward, the variables output_hidden_states and output_attentions are initialized to false and never used. The encoder_states vector is populated in the loop but its contents are never used or returned. This constitutes dead code and should be removed for clarity. The final return can also be simplified.

  torch::Tensor forward(const torch::Tensor& embeddings) {
    auto hidden_states = embeddings;
    for (size_t i = 0; i < layers_.size(); ++i) {
      auto& layer = layers_[i];
      hidden_states = layer->forward(hidden_states);
    }
    return hidden_states;
  }

  }

 private:
  torch::nn::ModuleList blocks_ = nullptr;

medium

The member variable blocks_ is declared in SiglipEncoderImpl but is never initialized or used. The layers_ member is used instead for storing the encoder layers. This unused variable should be removed to clean up the code.

Comment on lines +68 to +83
std::vector<int64_t> SiglipImageProcessor::get_resize_output_image_size(
    const torch::Tensor& image,
    int shortest_edge) {
  int height = image.size(1);
  int width = image.size(2);

  int short_size = std::min(height, width);
  int long_size = std::max(height, width);

  int new_short = shortest_edge;
  int new_long = static_cast<int>(shortest_edge *
                                  static_cast<float>(long_size) / short_size);

  return height < width ? std::vector<int64_t>({new_short, new_long})
                        : std::vector<int64_t>({new_long, new_short});
}

medium

The private method get_resize_output_image_size is defined but it is not called anywhere within the SiglipImageProcessor class. This appears to be dead code and should be removed if it's not intended for future use.

