@@ -660,6 +660,9 @@ struct SD3CLIPEmbedder : public Conditioner {
660660 std::shared_ptr<CLIPTextModelRunner> clip_l;
661661 std::shared_ptr<CLIPTextModelRunner> clip_g;
662662 std::shared_ptr<T5Runner> t5;
663+ bool use_clip_l = false ;
664+ bool use_clip_g = false ;
665+ bool use_t5 = false ;
663666
664667 SD3CLIPEmbedder (ggml_backend_t backend,
665668 std::map<std::string, enum ggml_type>& tensor_types,
@@ -668,38 +671,93 @@ struct SD3CLIPEmbedder : public Conditioner {
668671 if (clip_skip <= 0 ) {
669672 clip_skip = 2 ;
670673 }
671- clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, false );
672- clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_g.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
673- t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
674+
675+ for (auto pair : tensor_types) {
676+ if (pair.first .find (" text_encoders.clip_l" ) != std::string::npos) {
677+ use_clip_l = true ;
678+ } else if (pair.first .find (" text_encoders.clip_g" ) != std::string::npos) {
679+ use_clip_g = true ;
680+ } else if (pair.first .find (" text_encoders.t5xxl" ) != std::string::npos) {
681+ use_t5 = true ;
682+ }
683+ }
684+ if (!use_clip_l && !use_clip_g && !use_t5) {
685+ LOG_WARN (" IMPORTANT NOTICE: No text encoders provided, cannot process prompts!" );
686+ return ;
687+ }
688+ if (use_clip_l) {
689+ clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, false );
690+ } else {
691+ LOG_WARN (" clip_l text encoder not found! Prompt adherence might be degraded." );
692+ }
693+ if (use_clip_g) {
694+ clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_g.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
695+ } else {
696+ LOG_WARN (" clip_g text encoder not found! Prompt adherence might be degraded." );
697+ }
698+ if (use_t5) {
699+ t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
700+ } else {
701+ LOG_WARN (" t5xxl text encoder not found! Prompt adherence might be degraded." );
702+ }
674703 }
675704
676705 void set_clip_skip (int clip_skip) {
677- clip_l->set_clip_skip (clip_skip);
678- clip_g->set_clip_skip (clip_skip);
706+ if (use_clip_l) {
707+ clip_l->set_clip_skip (clip_skip);
708+ }
709+ if (use_clip_g) {
710+ clip_g->set_clip_skip (clip_skip);
711+ }
679712 }
680713
681714 void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
682- clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
683- clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
684- t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
715+ if (use_clip_l) {
716+ clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
717+ }
718+ if (use_clip_g) {
719+ clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
720+ }
721+ if (use_t5) {
722+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
723+ }
685724 }
686725
687726 void alloc_params_buffer () {
688- clip_l->alloc_params_buffer ();
689- clip_g->alloc_params_buffer ();
690- t5->alloc_params_buffer ();
727+ if (use_clip_l) {
728+ clip_l->alloc_params_buffer ();
729+ }
730+ if (use_clip_g) {
731+ clip_g->alloc_params_buffer ();
732+ }
733+ if (use_t5) {
734+ t5->alloc_params_buffer ();
735+ }
691736 }
692737
693738 void free_params_buffer () {
694- clip_l->free_params_buffer ();
695- clip_g->free_params_buffer ();
696- t5->free_params_buffer ();
739+ if (use_clip_l) {
740+ clip_l->free_params_buffer ();
741+ }
742+ if (use_clip_g) {
743+ clip_g->free_params_buffer ();
744+ }
745+ if (use_t5) {
746+ t5->free_params_buffer ();
747+ }
697748 }
698749
699750 size_t get_params_buffer_size () {
700- size_t buffer_size = clip_l->get_params_buffer_size ();
701- buffer_size += clip_g->get_params_buffer_size ();
702- buffer_size += t5->get_params_buffer_size ();
751+ size_t buffer_size = 0 ;
752+ if (use_clip_l) {
753+ buffer_size += clip_l->get_params_buffer_size ();
754+ }
755+ if (use_clip_g) {
756+ buffer_size += clip_g->get_params_buffer_size ();
757+ }
758+ if (use_t5) {
759+ buffer_size += t5->get_params_buffer_size ();
760+ }
703761 return buffer_size;
704762 }
705763
@@ -731,23 +789,32 @@ struct SD3CLIPEmbedder : public Conditioner {
731789 for (const auto & item : parsed_attention) {
732790 const std::string& curr_text = item.first ;
733791 float curr_weight = item.second ;
734-
735- std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
736- clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
737- clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
738-
739- curr_tokens = clip_g_tokenizer.encode (curr_text, on_new_token_cb);
740- clip_g_tokens.insert (clip_g_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
741- clip_g_weights.insert (clip_g_weights.end (), curr_tokens.size (), curr_weight);
742-
743- curr_tokens = t5_tokenizer.Encode (curr_text, true );
744- t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
745- t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
792+ if (use_clip_l) {
793+ std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
794+ clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
795+ clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
796+ }
797+ if (use_clip_g) {
798+ std::vector<int > curr_tokens = clip_g_tokenizer.encode (curr_text, on_new_token_cb);
799+ clip_g_tokens.insert (clip_g_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
800+ clip_g_weights.insert (clip_g_weights.end (), curr_tokens.size (), curr_weight);
801+ }
802+ if (use_t5) {
803+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
804+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
805+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
806+ }
746807 }
747808
748- clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, max_length, padding);
749- clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
750- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
809+ if (use_clip_l) {
810+ clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, max_length, padding);
811+ }
812+ if (use_clip_g) {
813+ clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
814+ }
815+ if (use_t5) {
816+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
817+ }
751818
752819 // for (int i = 0; i < clip_l_tokens.size(); i++) {
753820 // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -792,10 +859,10 @@ struct SD3CLIPEmbedder : public Conditioner {
792859 std::vector<float > hidden_states_vec;
793860
794861 size_t chunk_len = 77 ;
795- size_t chunk_count = clip_l_tokens.size () / chunk_len;
862+ size_t chunk_count = std::max ( std::max ( clip_l_tokens.size (), clip_g_tokens. size ()), t5_tokens. size () ) / chunk_len;
796863 for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
797864 // clip_l
798- {
865+ if (use_clip_l) {
799866 std::vector<int > chunk_tokens (clip_l_tokens.begin () + chunk_idx * chunk_len,
800867 clip_l_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
801868 std::vector<float > chunk_weights (clip_l_weights.begin () + chunk_idx * chunk_len,
@@ -840,10 +907,17 @@ struct SD3CLIPEmbedder : public Conditioner {
840907 &pooled_l,
841908 work_ctx);
842909 }
910+ } else {
911+ chunk_hidden_states_l = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 768 , chunk_len);
912+ ggml_set_f32 (chunk_hidden_states_l, 0 .f );
913+ if (chunk_idx == 0 ) {
914+ pooled_l = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
915+ ggml_set_f32 (pooled_l, 0 .f );
916+ }
843917 }
844918
845919 // clip_g
846- {
920+ if (use_clip_g) {
847921 std::vector<int > chunk_tokens (clip_g_tokens.begin () + chunk_idx * chunk_len,
848922 clip_g_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
849923 std::vector<float > chunk_weights (clip_g_weights.begin () + chunk_idx * chunk_len,
@@ -889,10 +963,17 @@ struct SD3CLIPEmbedder : public Conditioner {
889963 &pooled_g,
890964 work_ctx);
891965 }
966+ } else {
967+ chunk_hidden_states_g = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 1280 , chunk_len);
968+ ggml_set_f32 (chunk_hidden_states_g, 0 .f );
969+ if (chunk_idx == 0 ) {
970+ pooled_g = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 1280 );
971+ ggml_set_f32 (pooled_g, 0 .f );
972+ }
892973 }
893974
894975 // t5
895- {
976+ if (use_t5) {
896977 std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
897978 t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
898979 std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
@@ -919,6 +1000,8 @@ struct SD3CLIPEmbedder : public Conditioner {
9191000 float new_mean = ggml_tensor_mean (tensor);
9201001 ggml_tensor_scale (tensor, (original_mean / new_mean));
9211002 }
1003+ } else {
1004+ chunk_hidden_states_t5 = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 0 );
9221005 }
9231006
9241007 auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d (work_ctx,
@@ -961,11 +1044,19 @@ struct SD3CLIPEmbedder : public Conditioner {
9611044 ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
9621045 }
9631046
964- hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
965- hidden_states = ggml_reshape_2d (work_ctx,
966- hidden_states,
967- chunk_hidden_states->ne [0 ],
968- ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1047+ if (hidden_states_vec.size () > 0 ) {
1048+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1049+ hidden_states = ggml_reshape_2d (work_ctx,
1050+ hidden_states,
1051+ chunk_hidden_states->ne [0 ],
1052+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1053+ } else {
1054+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 0 );
1055+ }
1056+ if (pooled == NULL ) {
1057+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 2048 );
1058+ ggml_set_f32 (pooled, 0 .f );
1059+ }
9691060 return SDCondition (hidden_states, pooled, NULL );
9701061 }
9711062
0 commit comments