
Commit c13cf04

conditioner: make text encoders optional for SD3.x
1 parent 10c6501 commit c13cf04
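
Reviewer note: the core of this change is that SD3CLIPEmbedder now scans the loaded tensor names to decide which of the three text encoders (clip_l, clip_g, t5xxl) are actually present, instead of unconditionally constructing all of them. The standalone sketch below distills that detection step; the map contents are hypothetical and the tensor types are replaced by plain ints so it compiles on its own.

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
        // Hypothetical tensor-name map for a checkpoint that ships only clip_g and t5xxl weights.
        std::map<std::string, int> tensor_types = {
            {"text_encoders.clip_g.transformer.text_model.embeddings.token_embedding.weight", 0},
            {"text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight", 0},
        };
        bool use_clip_l = false, use_clip_g = false, use_t5 = false;
        // Same substring checks as the constructor in this diff.
        for (const auto& pair : tensor_types) {
            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
                use_clip_l = true;
            } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
                use_clip_g = true;
            } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
                use_t5 = true;
            }
        }
        std::cout << "clip_l=" << use_clip_l << " clip_g=" << use_clip_g << " t5=" << use_t5 << "\n";
        // Prints: clip_l=0 clip_g=1 t5=1
    }

Only the runners whose flag is set get constructed; a missing encoder logs a warning and is replaced by zero-filled embeddings during conditioning.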

File tree: 1 file changed

conditioner.hpp

Lines changed: 132 additions & 41 deletions
@@ -660,6 +660,9 @@ struct SD3CLIPEmbedder : public Conditioner {
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<CLIPTextModelRunner> clip_g;
     std::shared_ptr<T5Runner> t5;
+    bool use_clip_l = false;
+    bool use_clip_g = false;
+    bool use_t5 = false;
 
     SD3CLIPEmbedder(ggml_backend_t backend,
                     std::map<std::string, enum ggml_type>& tensor_types,
@@ -668,38 +671,93 @@ struct SD3CLIPEmbedder : public Conditioner {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
-        clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
-        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+
+        for (auto pair : tensor_types) {
+            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
+                use_clip_l = true;
+            } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
+                use_clip_g = true;
+            } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
+                use_t5 = true;
+            }
+        }
+        if (!use_clip_l && !use_clip_g && !use_t5) {
+            LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
+            return;
+        }
+        if (use_clip_l) {
+            clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+        } else {
+            LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_clip_g) {
+            clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+        } else {
+            LOG_WARN("clip_g text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_t5) {
+            t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        } else {
+            LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
+        }
     }
 
     void set_clip_skip(int clip_skip) {
-        clip_l->set_clip_skip(clip_skip);
-        clip_g->set_clip_skip(clip_skip);
+        if (use_clip_l) {
+            clip_l->set_clip_skip(clip_skip);
+        }
+        if (use_clip_g) {
+            clip_g->set_clip_skip(clip_skip);
+        }
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
-        clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_clip_l) {
+            clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+        }
+        if (use_clip_g) {
+            clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
+        }
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }
 
     void alloc_params_buffer() {
-        clip_l->alloc_params_buffer();
-        clip_g->alloc_params_buffer();
-        t5->alloc_params_buffer();
+        if (use_clip_l) {
+            clip_l->alloc_params_buffer();
+        }
+        if (use_clip_g) {
+            clip_g->alloc_params_buffer();
+        }
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }
 
     void free_params_buffer() {
-        clip_l->free_params_buffer();
-        clip_g->free_params_buffer();
-        t5->free_params_buffer();
+        if (use_clip_l) {
+            clip_l->free_params_buffer();
+        }
+        if (use_clip_g) {
+            clip_g->free_params_buffer();
+        }
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }
 
     size_t get_params_buffer_size() {
-        size_t buffer_size = clip_l->get_params_buffer_size();
-        buffer_size += clip_g->get_params_buffer_size();
-        buffer_size += t5->get_params_buffer_size();
+        size_t buffer_size = 0;
+        if (use_clip_l) {
+            buffer_size += clip_l->get_params_buffer_size();
+        }
+        if (use_clip_g) {
+            buffer_size += clip_g->get_params_buffer_size();
+        }
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }
 
@@ -731,23 +789,32 @@ struct SD3CLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight            = item.second;
-
-            std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            if (use_clip_l) {
+                std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_clip_g) {
+                std::vector<int> curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_t5) {
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }
         }
 
-        clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
-        clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        if (use_clip_l) {
+            clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
+        }
+        if (use_clip_g) {
+            clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
+        }
+        if (use_t5) {
+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        }
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -792,10 +859,10 @@ struct SD3CLIPEmbedder : public Conditioner {
         std::vector<float> hidden_states_vec;
 
         size_t chunk_len   = 77;
-        size_t chunk_count = clip_l_tokens.size() / chunk_len;
+        size_t chunk_count = std::max(std::max(clip_l_tokens.size(), clip_g_tokens.size()), t5_tokens.size()) / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
-            {
+            if (use_clip_l) {
                 std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
                                               clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
@@ -840,10 +907,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     &pooled_l,
                                     work_ctx);
                 }
+            } else {
+                chunk_hidden_states_l = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, chunk_len);
+                ggml_set_f32(chunk_hidden_states_l, 0.f);
+                if (chunk_idx == 0) {
+                    pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
+                    ggml_set_f32(pooled_l, 0.f);
+                }
             }
 
             // clip_g
-            {
+            if (use_clip_g) {
                 std::vector<int> chunk_tokens(clip_g_tokens.begin() + chunk_idx * chunk_len,
                                               clip_g_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(clip_g_weights.begin() + chunk_idx * chunk_len,
@@ -889,10 +963,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     &pooled_g,
                                     work_ctx);
                 }
+            } else {
+                chunk_hidden_states_g = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 1280, chunk_len);
+                ggml_set_f32(chunk_hidden_states_g, 0.f);
+                if (chunk_idx == 0) {
+                    pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
+                    ggml_set_f32(pooled_g, 0.f);
+                }
             }
 
             // t5
-            {
+            if (use_t5) {
                 std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
                                               t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -919,6 +1000,8 @@ struct SD3CLIPEmbedder : public Conditioner {
                     float new_mean = ggml_tensor_mean(tensor);
                     ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
+            } else {
+                chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
             }
 
             auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d(work_ctx,
@@ -961,11 +1044,19 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
         }
 
-        hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
-        hidden_states = ggml_reshape_2d(work_ctx,
-                                        hidden_states,
-                                        chunk_hidden_states->ne[0],
-                                        ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
+        }
+        if (pooled == NULL) {
+            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 2048);
+            ggml_set_f32(pooled, 0.f);
+        }
         return SDCondition(hidden_states, pooled, NULL);
     }
 
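A side note on the fallback path, as read from the diff above (not part of the commit text itself): when an encoder is absent, its per-chunk hidden states are stood in by zero tensors of the expected width (768 for clip_l, 1280 for clip_g, and a 4096-wide, zero-length tensor for t5xxl), and if no encoder produced a pooled vector the embedder falls back to a zero-filled pooled tensor of size 2048 (768 + 1280). Below is a minimal sketch of that zero-fill pattern using the ggml calls that appear in the diff; the context size and setup are assumptions made only to keep the example self-contained.

    #include "ggml.h"

    int main() {
        // Small scratch context, analogous to the work_ctx used in the diff (size is arbitrary here).
        struct ggml_init_params params = {
            /*mem_size*/ 16 * 1024 * 1024,
            /*mem_buffer*/ NULL,
            /*no_alloc*/ false,
        };
        struct ggml_context* work_ctx = ggml_init(params);

        const int64_t chunk_len = 77;
        // clip_l missing: zero hidden states (768 x 77) and a zero pooled vector (768), as in the diff's else branch.
        struct ggml_tensor* chunk_hidden_states_l = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, chunk_len);
        ggml_set_f32(chunk_hidden_states_l, 0.f);
        struct ggml_tensor* pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
        ggml_set_f32(pooled_l, 0.f);

        ggml_free(work_ctx);
        return 0;
    }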