@@ -11,14 +11,19 @@
 from cogkit.finetune import register
 from cogkit.finetune.diffusion.schemas import DiffusionComponents
 from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-from cogkit.finetune.utils import process_prompt_attention_mask, unwrap_model
+from cogkit.finetune.utils import (
+    process_prompt_attention_mask,
+    unwrap_model,
+    replace_attn_processor,
+)
 from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint
 from diffusers import (
     AutoencoderKL,
     CogView4Pipeline,
     CogView4Transformer2DModel,
     FlowMatchEulerDiscreteScheduler,
 )
+from diffusers.models.transformers.transformer_cogview4 import CogView4TrainingAttnProcessor
 
 
 class Cogview4Trainer(DiffusionTrainer):
@@ -68,6 +73,7 @@ def load_components(self) -> DiffusionComponents:
             quantization_config=nf4_config,
             device=self.accelerator.device,
         )
+        replace_attn_processor(components.transformer, CogView4TrainingAttnProcessor())
 
         ### vae
         components.vae = AutoencoderKL.from_pretrained(
@@ -98,6 +104,7 @@ def initialize_pipeline(self, ckpt_path: str | None = None) -> CogView4Pipeline:
             subfolder="transformer",
             torch_dtype=self.state.weight_dtype,
         )
+        replace_attn_processor(transformer, CogView4TrainingAttnProcessor())
         pipe = CogView4Pipeline(
             tokenizer=self.components.tokenizer,
             text_encoder=self.components.text_encoder,
@@ -170,7 +177,7 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
             - 'prompt_embedding': Tensor of shape [batch_size, sequence_length, embedding_dim]
             - 'image': List of image tensors (will be empty during validation)
             - 'encoded_image': Tensor of shape [batch_size, channels, height, width] (None during validation)
-            - 'attention_mask': Dictionary with 'text_embedding_attn_mask' for transformer attention
+            - 'text_attn_mask': Tensor of shape [batch_size, sequence_length] for transformer attention
 
         Note:
             This function assumes that all images in the batch have the same resolution.
@@ -180,7 +187,7 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
180187 "prompt_embedding" : [],
181188 "image" : [],
182189 "encoded_image" : [],
183- "attention_mask " : { "text_embedding_attn_mask" : None } ,
190+ "text_attn_mask " : None ,
184191 }
185192
186193 for sample in samples :
@@ -206,15 +213,12 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
         )
 
         ret["prompt_embedding"] = prompt_embedding
-        ret["attention_mask"]["text_embedding_attn_mask"] = prompt_attention_mask
+        ret["text_attn_mask"] = prompt_attention_mask
 
         ret["encoded_image"] = torch.stack(ret["encoded_image"]) if ret["encoded_image"] else None
 
         # shape of prompt_embedding: [batch_size, sequence_length, embedding_dim(4096)]
-        assert (
-            ret["attention_mask"]["text_embedding_attn_mask"].shape
-            == ret["prompt_embedding"].shape[:2]
-        )
+        assert ret["text_attn_mask"].shape == ret["prompt_embedding"].shape[:2]
 
         return ret
 
@@ -232,7 +236,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
         ) // (self.state.transformer_config.patch_size**2)
         image_seq_len = torch.tensor([image_seq_len], device=self.accelerator.device)
 
-        attention_mask = batch["attention_mask"]
+        text_attn_mask = batch["text_attn_mask"]
 
         num_train_timesteps = self.components.scheduler.config.num_train_timesteps
         sigmas = self.get_sigmas(batch_size, image_seq_len)
@@ -263,7 +267,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
             target_size=target_size,
             crop_coords=crop_coords,
             return_dict=False,
-            attention_mask=attention_mask,
+            attention_kwargs={"text_attn_mask": text_attn_mask},
         )[0]
 
         loss = torch.mean((noise_pred_cond - model_label) ** 2, dim=(1, 2, 3))
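Below is a minimal sketch (not part of this commit; the toy shapes and the commented-out transformer call are illustrative assumptions) of the calling convention after this change: collate_fn returns a flat text_attn_mask tensor instead of a nested attention_mask dict, and compute_loss forwards it through attention_kwargs.

import torch

# Assumed toy shapes: 2 prompts, 16 text tokens, 4096-dim embeddings.
batch = {
    "prompt_embedding": torch.randn(2, 16, 4096),
    "text_attn_mask": torch.ones(2, 16, dtype=torch.long),  # flat tensor, no nested dict
}

# Same invariant the updated collate_fn asserts.
assert batch["text_attn_mask"].shape == batch["prompt_embedding"].shape[:2]

# The mask now reaches the attention processor via attention_kwargs; argument
# names other than attention_kwargs / text_attn_mask are placeholders.
# noise_pred = transformer(
#     hidden_states=latent_model_input,
#     encoder_hidden_states=batch["prompt_embedding"],
#     attention_kwargs={"text_attn_mask": batch["text_attn_mask"]},
#     return_dict=False,
# )[0]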