@@ -1096,38 +1096,82 @@ struct FluxCLIPEmbedder : public Conditioner {
std::shared_ptr<CLIPTextModelRunner> clip_l;  // CLIP-L text encoder runner; stays null when no clip_l weights are present
std::shared_ptr<T5Runner> t5;                 // T5-XXL text encoder runner; stays null when no t5xxl weights are present

// Set in the constructor from the tensor names found in tensor_types.
// Every use of clip_l / t5 below is guarded by these flags, since either
// encoder may be missing from the loaded model.
bool use_clip_l = false;
bool use_t5    = false;
10991102 FluxCLIPEmbedder (ggml_backend_t backend,
11001103 std::map<std::string, enum ggml_type>& tensor_types,
11011104 int clip_skip = -1 ) {
11021105 if (clip_skip <= 0 ) {
11031106 clip_skip = 2 ;
11041107 }
1105- clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, true );
1106- t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
1108+
1109+ for (auto pair : tensor_types) {
1110+ if (pair.first .find (" text_encoders.clip_l" ) != std::string::npos) {
1111+ use_clip_l = true ;
1112+ } else if (pair.first .find (" text_encoders.t5xxl" ) != std::string::npos) {
1113+ use_t5 = true ;
1114+ }
1115+ }
1116+
1117+ if (!use_clip_l && !use_t5) {
1118+ LOG_WARN (" IMPORTANT NOTICE: No text encoders provided, cannot process prompts!" );
1119+ return ;
1120+ }
1121+
1122+ if (use_clip_l) {
1123+ clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, true );
1124+ } else {
1125+ LOG_WARN (" clip_l text encoder not found! Prompt adherence might be degraded." );
1126+ }
1127+ if (use_t5) {
1128+ t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
1129+ } else {
1130+ LOG_WARN (" t5xxl text encoder not found! Prompt adherence might be degraded." );
1131+ }
11071132 }
11081133
11091134 void set_clip_skip (int clip_skip) {
1110- clip_l->set_clip_skip (clip_skip);
1135+ if (use_clip_l) {
1136+ clip_l->set_clip_skip (clip_skip);
1137+ }
11111138 }
11121139
11131140 void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
1114- clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
1115- t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1141+ if (use_clip_l) {
1142+ clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
1143+ }
1144+ if (use_t5) {
1145+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1146+ }
11161147 }
11171148
11181149 void alloc_params_buffer () {
1119- clip_l->alloc_params_buffer ();
1120- t5->alloc_params_buffer ();
1150+ if (use_clip_l) {
1151+ clip_l->alloc_params_buffer ();
1152+ }
1153+ if (use_t5) {
1154+ t5->alloc_params_buffer ();
1155+ }
11211156 }
11221157
11231158 void free_params_buffer () {
1124- clip_l->free_params_buffer ();
1125- t5->free_params_buffer ();
1159+ if (use_clip_l) {
1160+ clip_l->free_params_buffer ();
1161+ }
1162+ if (use_t5) {
1163+ t5->free_params_buffer ();
1164+ }
11261165 }
11271166
11281167 size_t get_params_buffer_size () {
1129- size_t buffer_size = clip_l->get_params_buffer_size ();
1130- buffer_size += t5->get_params_buffer_size ();
1168+ size_t buffer_size = 0 ;
1169+ if (use_clip_l) {
1170+ buffer_size += clip_l->get_params_buffer_size ();
1171+ }
1172+ if (use_t5) {
1173+ buffer_size += t5->get_params_buffer_size ();
1174+ }
11311175 return buffer_size;
11321176 }
11331177
@@ -1157,18 +1201,23 @@ struct FluxCLIPEmbedder : public Conditioner {
11571201 for (const auto & item : parsed_attention) {
11581202 const std::string& curr_text = item.first ;
11591203 float curr_weight = item.second ;
1160-
1161- std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
1162- clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1163- clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
1164-
1165- curr_tokens = t5_tokenizer.Encode (curr_text, true );
1166- t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1167- t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1204+ if (use_clip_l) {
1205+ std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
1206+ clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1207+ clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
1208+ }
1209+ if (use_t5) {
1210+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
1211+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1212+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1213+ }
1214+ }
1215+ if (use_clip_l) {
1216+ clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
1217+ }
1218+ if (use_t5) {
1219+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
11681220 }
1169-
1170- clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
1171- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
11721221
11731222 // for (int i = 0; i < clip_l_tokens.size(); i++) {
11741223 // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1201,34 +1250,36 @@ struct FluxCLIPEmbedder : public Conditioner {
12011250 std::vector<float > hidden_states_vec;
12021251
12031252 size_t chunk_len = 256 ;
1204- size_t chunk_count = t5_tokens.size () / chunk_len;
1253+ size_t chunk_count = std::max (clip_l_tokens. size () > 0 ? chunk_len : 0 , t5_tokens.size () ) / chunk_len;
12051254 for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
12061255 // clip_l
12071256 if (chunk_idx == 0 ) {
1208- size_t chunk_len_l = 77 ;
1209- std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1210- clip_l_tokens.begin () + chunk_len_l);
1211- std::vector<float > chunk_weights (clip_l_weights.begin (),
1212- clip_l_weights.begin () + chunk_len_l);
1257+ if (use_clip_l) {
1258+ size_t chunk_len_l = 77 ;
1259+ std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1260+ clip_l_tokens.begin () + chunk_len_l);
1261+ std::vector<float > chunk_weights (clip_l_weights.begin (),
1262+ clip_l_weights.begin () + chunk_len_l);
12131263
1214- auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1215- size_t max_token_idx = 0 ;
1264+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1265+ size_t max_token_idx = 0 ;
12161266
1217- auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), clip_l_tokenizer.EOS_TOKEN_ID );
1218- max_token_idx = std::min<size_t >(std::distance (chunk_tokens.begin (), it), chunk_tokens.size () - 1 );
1267+ auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), clip_l_tokenizer.EOS_TOKEN_ID );
1268+ max_token_idx = std::min<size_t >(std::distance (chunk_tokens.begin (), it), chunk_tokens.size () - 1 );
12191269
1220- clip_l->compute (n_threads,
1221- input_ids,
1222- 0 ,
1223- NULL ,
1224- max_token_idx,
1225- true ,
1226- &pooled,
1227- work_ctx);
1270+ clip_l->compute (n_threads,
1271+ input_ids,
1272+ 0 ,
1273+ NULL ,
1274+ max_token_idx,
1275+ true ,
1276+ &pooled,
1277+ work_ctx);
1278+ }
12281279 }
12291280
12301281 // t5
1231- {
1282+ if (use_t5) {
12321283 std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
12331284 t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
12341285 std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
@@ -1255,8 +1306,12 @@ struct FluxCLIPEmbedder : public Conditioner {
12551306 float new_mean = ggml_tensor_mean (tensor);
12561307 ggml_tensor_scale (tensor, (original_mean / new_mean));
12571308 }
1309+ } else {
1310+ chunk_hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , chunk_len);
1311+ ggml_set_f32 (chunk_hidden_states, 0 .f );
12581312 }
12591313
1314+
12601315 int64_t t1 = ggml_time_ms ();
12611316 LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
12621317 if (force_zero_embeddings) {
@@ -1265,17 +1320,26 @@ struct FluxCLIPEmbedder : public Conditioner {
12651320 vec[i] = 0 ;
12661321 }
12671322 }
1268-
1323+
12691324 hidden_states_vec.insert (hidden_states_vec.end (),
1270- (float *)chunk_hidden_states->data ,
1271- ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1325+ (float *)chunk_hidden_states->data ,
1326+ ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1327+ }
1328+
1329+ if (hidden_states_vec.size () > 0 ) {
1330+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1331+ hidden_states = ggml_reshape_2d (work_ctx,
1332+ hidden_states,
1333+ chunk_hidden_states->ne [0 ],
1334+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1335+ } else {
1336+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 256 );
1337+ ggml_set_f32 (hidden_states, 0 .f );
1338+ }
1339+ if (pooled == NULL ) {
1340+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
1341+ ggml_set_f32 (pooled, 0 .f );
12721342 }
1273-
1274- hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1275- hidden_states = ggml_reshape_2d (work_ctx,
1276- hidden_states,
1277- chunk_hidden_states->ne [0 ],
1278- ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
12791343 return SDCondition (hidden_states, pooled, NULL );
12801344 }
12811345
0 commit comments