
Commit 1bfef3f

[chore] Align with diffusers

1 parent e4f6da0
3 files changed: +35 −24 lines changed


src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py

Lines changed: 14 additions & 10 deletions
@@ -11,14 +11,19 @@
 from cogkit.finetune import register
 from cogkit.finetune.diffusion.schemas import DiffusionComponents
 from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-from cogkit.finetune.utils import process_prompt_attention_mask, unwrap_model
+from cogkit.finetune.utils import (
+    process_prompt_attention_mask,
+    unwrap_model,
+    replace_attn_processor,
+)
 from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint
 from diffusers import (
     AutoencoderKL,
     CogView4Pipeline,
     CogView4Transformer2DModel,
     FlowMatchEulerDiscreteScheduler,
 )
+from diffusers.models.transformers.transformer_cogview4 import CogView4TrainingAttnProcessor


 class Cogview4Trainer(DiffusionTrainer):

@@ -68,6 +73,7 @@ def load_components(self) -> DiffusionComponents:
             quantization_config=nf4_config,
             device=self.accelerator.device,
         )
+        replace_attn_processor(components.transformer, CogView4TrainingAttnProcessor())

         ### vae
         components.vae = AutoencoderKL.from_pretrained(

@@ -98,6 +104,7 @@ def initialize_pipeline(self, ckpt_path: str | None = None) -> CogView4Pipeline:
             subfolder="transformer",
             torch_dtype=self.state.weight_dtype,
         )
+        replace_attn_processor(transformer, CogView4TrainingAttnProcessor())
         pipe = CogView4Pipeline(
             tokenizer=self.components.tokenizer,
             text_encoder=self.components.text_encoder,

@@ -170,7 +177,7 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
            - 'prompt_embedding': Tensor of shape [batch_size, sequence_length, embedding_dim]
            - 'image': List of image tensors (will be empty during validation)
            - 'encoded_image': Tensor of shape [batch_size, channels, height, width] (None during validation)
-           - 'attention_mask': Dictionary with 'text_embedding_attn_mask' for transformer attention
+           - 'text_attn_mask': Tensor of shape [batch_size, sequence_length] for transformer attention

         Note:
             This function assumes that all images in the batch have the same resolution.

@@ -180,7 +187,7 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
             "prompt_embedding": [],
             "image": [],
             "encoded_image": [],
-            "attention_mask": {"text_embedding_attn_mask": None},
+            "text_attn_mask": None,
         }

         for sample in samples:

@@ -206,15 +213,12 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
         )

         ret["prompt_embedding"] = prompt_embedding
-        ret["attention_mask"]["text_embedding_attn_mask"] = prompt_attention_mask
+        ret["text_attn_mask"] = prompt_attention_mask

         ret["encoded_image"] = torch.stack(ret["encoded_image"]) if ret["encoded_image"] else None

         # shape of prompt_embedding: [batch_size, sequence_length, embedding_dim(4096)]
-        assert (
-            ret["attention_mask"]["text_embedding_attn_mask"].shape
-            == ret["prompt_embedding"].shape[:2]
-        )
+        assert ret["text_attn_mask"].shape == ret["prompt_embedding"].shape[:2]

         return ret

@@ -232,7 +236,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
         ) // (self.state.transformer_config.patch_size**2)
         image_seq_len = torch.tensor([image_seq_len], device=self.accelerator.device)

-        attention_mask = batch["attention_mask"]
+        text_attn_mask = batch["text_attn_mask"]

         num_train_timesteps = self.components.scheduler.config.num_train_timesteps
         sigmas = self.get_sigmas(batch_size, image_seq_len)

@@ -263,7 +267,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
             target_size=target_size,
             crop_coords=crop_coords,
             return_dict=False,
-            attention_mask=attention_mask,
+            attention_kwargs={"text_attn_mask": text_attn_mask},
         )[0]

         loss = torch.mean((noise_pred_cond - model_label) ** 2, dim=(1, 2, 3))
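Note: after this change, collate_fn carries the text mask as a flat text_attn_mask tensor instead of the old nested attention_mask dict, and compute_loss forwards it to the transformer through attention_kwargs. A minimal sketch of the new batch layout (tensor sizes here are illustrative, not taken from the commit):

import torch

# Illustrative sizes; the 4096 embedding dim matches the comment in collate_fn.
batch_size, seq_len, embedding_dim = 2, 224, 4096
batch = {
    "prompt_embedding": torch.randn(batch_size, seq_len, embedding_dim),
    "text_attn_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}
# The invariant the simplified assert enforces: mask shape == [batch_size, sequence_length].
assert batch["text_attn_mask"].shape == batch["prompt_embedding"].shape[:2]
# compute_loss now wraps the mask for the transformer call:
attention_kwargs = {"text_attn_mask": batch["text_attn_mask"]}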

src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py

Lines changed: 13 additions & 13 deletions
@@ -99,10 +99,10 @@ def collate_fn_packing(self, samples: list[dict[str, list[Any]]]) -> dict[str, A
            - prompt_embedding: Batched prompt embeddings
            - encoded_image: Batched encoded image latents
            - image_rotary_emb: Rotary embeddings for images
-           - attention_mask: Dictionary containing:
+           - attention_kwargs: Dictionary containing:
                - batch_flag: Indices indicating which sample each item belongs to
-               - text_embedding_attn_mask: Attention mask for text embeddings
-               - latent_embedding_attn_mask: Attention mask for latent embeddings
+               - text_attn_mask: Attention mask for text embeddings
+               - latent_attn_mask: Attention mask for latent embeddings
            - pixel_mask: Mask for valid pixel regions
            - original_size: Original dimensions of the images

@@ -114,10 +114,10 @@ def collate_fn_packing(self, samples: list[dict[str, list[Any]]]) -> dict[str, A
             "prompt_embedding": None,
             "encoded_image": None,
             "image_rotary_emb": None,
-            "attention_mask": {
+            "attention_kwargs": {
                 "batch_flag": None,
-                "text_embedding_attn_mask": None,
-                "latent_embedding_attn_mask": None,
+                "text_attn_mask": None,
+                "latent_attn_mask": None,
             },
             "pixel_mask": None,
             "original_size": None,

@@ -144,15 +144,15 @@ def collate_fn_packing(self, samples: list[dict[str, list[Any]]]) -> dict[str, A

         # Store in batched_data
         batched_data["prompt_embedding"] = prompt_embedding
-        batched_data["attention_mask"]["text_embedding_attn_mask"] = prompt_attention_mask
+        batched_data["attention_kwargs"]["text_attn_mask"] = prompt_attention_mask
         batched_data["encoded_image"] = padded_latent
         batched_data["image_rotary_emb"] = image_rotary_emb
-        batched_data["attention_mask"]["latent_embedding_attn_mask"] = (
-            vtoken_attention_mask.reshape(len(batch_flag), -1)
+        batched_data["attention_kwargs"]["latent_attn_mask"] = vtoken_attention_mask.reshape(
+            len(batch_flag), -1
         )
         batched_data["pixel_mask"] = pixel_mask

-        batched_data["attention_mask"]["batch_flag"] = batch_flag
+        batched_data["attention_kwargs"]["batch_flag"] = batch_flag
         batched_data["original_size"] = torch.tensor(
             [(img.height, img.width) for img in samples["image"]]
         )

@@ -168,8 +168,8 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
         batch_size, text_seqlen, text_embedding_dim = prompt_embeds.shape
         batch_size, num_channels, height, width = latent.shape

-        attn_mask = batch["attention_mask"]
-        latent_attention_mask = attn_mask["latent_embedding_attn_mask"].float()
+        attention_kwargs = batch["attention_kwargs"]
+        latent_attention_mask = attention_kwargs["latent_attn_mask"].float()
         assert latent_attention_mask.dim() == 2
         vtoken_seq_len = torch.sum(latent_attention_mask != 0, dim=1)

@@ -196,8 +196,8 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
             target_size=target_size,
             crop_coords=crop_coords,
             return_dict=False,
-            attention_mask=attn_mask,
             image_rotary_emb=image_rotary_emb,
+            attention_kwargs=attention_kwargs,
         )[0]

         pixel_mask = batch["pixel_mask"]
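In the packing trainer the nested dict is renamed from attention_mask to attention_kwargs and its keys are shortened, so it can be handed to the transformer unchanged as its attention_kwargs argument. A rough sketch of the structure collate_fn_packing now builds (all values are placeholders, not real packed data):

import torch

# Placeholder sizes; real values depend on the packed batch.
num_images, text_seqlen, latent_seqlen = 3, 224, 1024
attention_kwargs = {
    # index of the originating sample for each packed item (placeholder values)
    "batch_flag": torch.tensor([0, 0, 1]),
    # 1 = real token, 0 = padding
    "text_attn_mask": torch.ones(num_images, text_seqlen, dtype=torch.long),
    "latent_attn_mask": torch.ones(num_images, latent_seqlen, dtype=torch.long),
}
# compute_loss then passes this dict straight through, e.g.
#   transformer(..., image_rotary_emb=image_rotary_emb,
#               attention_kwargs=attention_kwargs, return_dict=False)[0]
# instead of the old attention_mask=attn_mask argument.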

src/cogkit/finetune/utils/attn_mask.py

Lines changed: 8 additions & 1 deletion
@@ -1,8 +1,9 @@
 import math
-from typing import List, Tuple
+from typing import Any, List, Tuple

 import torch
 from transformers import AutoTokenizer
+from diffusers.models.attention_processor import Attention

 from .filters import MeanFilter

@@ -124,3 +125,9 @@ def process_latent_attention_mask(
     mask_assert(vtoken_attention_mask)

     return padded_latent, vtoken_attention_mask, pixel_mask
+
+
+def replace_attn_processor(model: torch.nn.Module, attn_processor_obj: Any) -> None:
+    for name, submodule in model.named_modules():
+        if isinstance(submodule, Attention):
+            submodule.processor = attn_processor_obj
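The new replace_attn_processor helper walks the module tree and swaps the processor on every diffusers Attention block; the two trainers above use it to install CogView4TrainingAttnProcessor on the loaded transformer. A hedged usage sketch (the checkpoint id and subfolder are illustrative; the import paths are the ones introduced in this diff):

from diffusers import CogView4Transformer2DModel
from diffusers.models.transformers.transformer_cogview4 import CogView4TrainingAttnProcessor

from cogkit.finetune.utils import replace_attn_processor

# Illustrative checkpoint id/subfolder; the trainers load these from their own config.
transformer = CogView4Transformer2DModel.from_pretrained(
    "THUDM/CogView4-6B", subfolder="transformer"
)
# Install the training-time processor on every Attention module so the masks
# forwarded via attention_kwargs are honored during the forward pass.
replace_attn_processor(transformer, CogView4TrainingAttnProcessor())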
