Skip to content

fix pytorch_lightning package problem #577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7eb7d5a
Add DreamBooth-AltDiffusion pipeline and update training configurations
whybfq May 31, 2025
51b3d91
Add DreamBooth-AltDiffusion pipeline and update training configurations
whybfq May 31, 2025
9e52a8d
Add DreamBooth-AltDiffusion pipeline and update training configurations
whybfq May 31, 2025
d112fbd
dreambooth, OOM
whybfq May 31, 2025
fd0fc96
the most important change to dreambooth.py
whybfq Jun 1, 2025
bb0bf09
Update requirements.txt
whybfq Jun 5, 2025
5d7bc05
Update dreambooth.py
whybfq Jun 5, 2025
fc18d13
enter FlagAI/examples/AltDiffusion
whybfq Jun 5, 2025
c83b0cc
Update dreambooth.py
whybfq Jun 5, 2025
f6e1931
Update dreambooth.py
whybfq Jun 5, 2025
050feef
Update generate.py
whybfq Jun 5, 2025
24f5f0c
Update generate.py
whybfq Jun 5, 2025
acccd88
Update generate.py
whybfq Jun 5, 2025
7cad87e
Update generate.py
whybfq Jun 7, 2025
df41efc
add comments
whybfq Jun 7, 2025
9bd7abe
del samples
whybfq Jun 8, 2025
79b9c85
change generate
whybfq Jun 8, 2025
507deac
Delete examples/AltDiffusion/AltDiffusionOutputs directory
whybfq Jun 8, 2025
e0f2682
Delete examples/AltDiffusion-m18/AltDiffusionOutputs directory
whybfq Jun 8, 2025
1448b8a
keep new model in different place
whybfq Jun 8, 2025
2b182e3
Merge remote-tracking branch 'origin/master'
whybfq Jun 8, 2025
504e3a9
keep new model in different place
whybfq Jun 8, 2025
e2d4a9c
keep new model in different place
whybfq Jun 8, 2025
99365e0
Kaggle Notebook | AltDiffusion | BasicVersion
whybfq Jun 8, 2025
5ce3c3e
keep new model in different place
whybfq Jun 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
503 changes: 503 additions & 0 deletions altdiffusionTest.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions examples/AltCLIP-m18/altclip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

loader = AutoLoader(
task_name="txt_img_matching",
model_name="AltCLIP-XLMR-L-m18", # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
model_name="AltCLIP-XLMR-L-m18", # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
model_dir="./checkpoints"
)

Expand All @@ -23,11 +23,11 @@ def inference():
image = Image.open("./examples/AltCLIP-m18//dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
padding=True,
truncation=True,
max_length=77,
return_tensors='pt')
tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
padding=True,
truncation=True,
max_length=77,
return_tensors='pt')

text = tokenizer_out["input_ids"].to(device)
attention_mask = tokenizer_out["attention_mask"].to(device)
Expand All @@ -38,5 +38,5 @@ def inference():

print(text_probs.cpu().numpy()[0].tolist())

if __name__=="__main__":
inference()
if __name__ == "__main__":
inference()
15 changes: 10 additions & 5 deletions examples/AltDiffusion-m18/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
model.eval()
model.to(device)
predictor = Predictor(model)
prompt = "Daenerys Targaryen as a mermeid with a piercing gaze wearing an enchanted bikini in an underwater magical forest, highly detailed face, realistic face, beautiful detailed eyes, fantasy art, in the style of artgerm, illustration, epic, fantasy, intricate, hyper detailed, artstation, concept art, smooth, sharp focus, ray tracing, vibrant, photorealistic"
negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, extra head, extra legs,fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
prompt = "สาวสวย"
# negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, extra head, extra legs,fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
seed = 553124
predictor.predict_generate_images(
prompt=prompt,negative_prompt=negative_prompt,seed=seed
)

result = predictor.predict_generate_images(
prompt=prompt,
# negative_prompt=negative_prompt,
outpath="./AltDiffusionOutputs",
ddim_steps=50,
seed=seed)
print(type(result), result)
179 changes: 102 additions & 77 deletions examples/AltDiffusion/dreambooth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
#
# https://huggingface.co/spaces/BAAI/dreambooth-altdiffusion/blob/main/train_dreambooth.py

import os
import sys
import random
Expand All @@ -12,66 +6,73 @@

import torch
from torch.utils.data import Dataset
from torch.cuda.amp import autocast, GradScaler # 混合精度训练
from PIL import Image
from torchvision import transforms

from flagai.trainer import Trainer
from flagai.auto_model.auto_loader import AutoLoader

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

instance_data_dir = "./examples/AltDiffusion/instance_images"
instance_prompt = "<鸣人>男孩"

with_prior_preservation = False
class_data_dir = "Mix"
class_prompt = "男孩"
prior_loss_weight = 1.0
num_class_images = 10
resolution = 512
center_crop = True

train_text_encoder = False
train_only_unet = True

num_train_epochs = 500
batch_size = 4
learning_rate = 5e-6
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-08

# =============== 训练参数配置 ===============
instance_data_dir = "./instance_images" # 实例图片目录
instance_prompt = "smile😁" # 实例提示词(包含特殊标识符)

with_prior_preservation = False # 是否使用先验保留
class_data_dir = "Mix" # 类别图片目录
class_prompt = "smile" # 类别提示词
prior_loss_weight = 1.0 # 先验损失权重
num_class_images = 4 # 类别图片数量
resolution = 128 # 图片分辨率
center_crop = False # 是否中心裁剪

train_text_encoder = False # 是否训练文本编码器
train_only_unet = True # 是否仅训练UNet

# 训练超参数
num_train_epochs = 10 # 训练轮数
batch_size = 2 # 批次大小
learning_rate = 5e-6 # 学习率
adam_beta1 = 0.9 # Adam优化器参数
adam_beta2 = 0.999 # Adam优化器参数
adam_weight_decay = 1e-2 # 权重衰减
adam_epsilon = 1e-08 # 数值稳定性常数

# =============== 模型初始化 ===============
# 加载AltDiffusion-m18文本生成图像模型
auto_loader = AutoLoader(task_name="text2img",
model_name="AltDiffusion")
model_name="AltDiffusion-m18")

model = auto_loader.get_model()
tokenizer = model.tokenizer
model = auto_loader.get_model() # 获取模型
tokenizer = model.tokenizer # 获取分词器

# =============== 数据集定义 ===============
class DreamBoothDataset(Dataset):
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
self,
instance_data_root, # 实例图片路径
instance_prompt, # 实例提示词
tokenizer, # 分词器
class_data_root=None, # 类别图片路径
class_prompt=None, # 类别提示词
size=512, # 图片大小
center_crop=False, # 是否中心裁剪
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer

# 处理实例图片
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images

# 处理类别图片(用于先验保留)
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
Expand All @@ -83,12 +84,13 @@ def __init__(
else:
self.class_data_root = None

# 图片预处理流程
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
transforms.Normalize([0.5], [0.5]), # 归一化到[-1, 1]
]
)

Expand All @@ -97,21 +99,23 @@ def __len__(self):

def __getitem__(self, index):
example = {}
# 加载实例图片
path = self.instance_images_path[index % self.num_instance_images]
instance_image = Image.open(path)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")


# 应用预处理并获取提示词token
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
print('*'*20, "instance_prompt=", self.instance_prompt)
example["caption"] = self.instance_prompt

# 加载类别图片(如果使用先验保留)
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
Expand All @@ -126,6 +130,7 @@ def __getitem__(self, index):

return example

# =============== 创建数据集和数据加载器 ===============
train_dataset = DreamBoothDataset(
instance_data_root=instance_data_dir,
instance_prompt=instance_prompt,
Expand All @@ -137,12 +142,12 @@ def __getitem__(self, index):
)

def collate_fn(examples):
# 批量处理数据
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
captions = [example["caption"] for example in examples]

# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
# 合并实例和类别数据(如果使用先验保留)
if with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
Expand All @@ -156,61 +161,81 @@ def collate_fn(examples):
"input_ids": input_ids,
"pixel_values": pixel_values,
"caption": captions,
"txt": captions
}
return batch

train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)

vae = model.first_stage_model
text_encoder = model.cond_stage_model
unet = model.model.diffusion_model
# =============== 模型组件准备 ===============
vae = model.first_stage_model # 变分自编码器
text_encoder = model.cond_stage_model # 文本编码器
unet = model.model.diffusion_model # 扩散模型(UNet)

# 冻结不需要训练的组件
vae.requires_grad_(False)
if not train_text_encoder:
text_encoder.requires_grad_(False)

# =============== 优化器设置 ===============
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters(),
text_encoder.parameters()) if train_text_encoder else unet.parameters())
optimizer = optimizer_class(
params_to_optimize,
lr=learning_rate,
betas=(adam_beta1, adam_beta2),
weight_decay=adam_weight_decay,
eps=adam_epsilon,
params_to_optimize,
lr=learning_rate,
betas=(adam_beta1, adam_beta2),
weight_decay=adam_weight_decay,
eps=adam_epsilon,
)

# =============== 混合精度训练 ===============
scaler = GradScaler() # 梯度缩放器(用于混合精度训练)

# 将模型移到设备
model.to(device)
vae = model.first_stage_model.to(device)
text_encoder = model.cond_stage_model.to(device)
unet = model.model.diffusion_model.to(device)

# 训练循环
for epoch in range(num_train_epochs):
unet.train()
if train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
#x = batch["pixel_values"].to(device)
#x = model.encode_first_stage(batch["pixel_values"]).to(device)
#c = batch["caption"]
# 获取输入数据
x, c = model.get_input(batch, "pixel_values")

if with_prior_preservation:
x, x_prior = torch.chunk(x, 2, dim=0)
c, c_prior = torch.chunk(c, 2, dim=0)
loss, _ = model(x, c)
prior_loss, _ = model(x_prior, c_prior)
loss = loss + prior_loss_weight * prior_loss
else:
loss, _ = model(x, c)

print('*'*20, "loss=", str(loss.detach().item()))

loss.backward()
optimizer.step()
optimizer.zero_grad()

## Before running: create ./checkpoints/DreamBooth and copy ./checkpoints/AltDiffusion to ./checkpoints/DreamBooth/AltDiffusion.
## The model.ckpt in that copy is then overwritten by the torch.save call below, for later use.
chekpoint_path = './checkpoints/DreamBooth/AltDiffusion/model.ckpt'
torch.save(model.state_dict(), chekpoint_path)

# 混合精度训练
with autocast(): # 自动混合精度上下文
if with_prior_preservation:
# 分离实例和先验数据
x, x_prior = torch.chunk(x, 2, dim=0)
c, c_prior = torch.chunk(c, 2, dim=0)
# 计算实例损失和先验损失
loss, _ = model(x, c)
prior_loss, _ = model(x_prior, c_prior)
# 组合损失
loss = loss + prior_loss_weight * prior_loss
else:
# 仅计算实例损失
loss, _ = model(x, c)

print('*' * 20, "loss=", str(loss.detach().item()))

# 反向传播和优化
scaler.scale(loss).backward() # 缩放梯度
scaler.step(optimizer) # 更新参数
scaler.update() # 更新缩放器
optimizer.zero_grad() # 清零梯度

# =============== 保存训练好的模型 ===============
checkpoint_path = './checkpoints/AltDiffusion-m18-new-trained/model.ckpt'
# 确保目录存在
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved to {checkpoint_path}")
12 changes: 9 additions & 3 deletions examples/AltDiffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(task_name="text2img", #contrastive learning
model_name="AltDiffusion-m9",
model_name="AltDiffusion-m18", # use m18 to do the experiment
model_dir="./checkpoints",
fp16=False)

Expand All @@ -18,5 +18,11 @@
model.to(device)
predictor = Predictor(model)
predictor.predict_generate_images(
"Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation"
)
prompt="smile😁",
# negative_prompt=negative_prompt,
# outpath="./AltDiffusionOutputs",
ddim_steps=20,
plms=True,
skip_grid=True, # or False if you want a grid image
)

Binary file added examples/AltDiffusion/instance_images/1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 9 additions & 2 deletions flagai/model/mm/AltCLIP.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from typing import Tuple, Any, Optional, Union

from transformers.models.clip.modeling_clip import *
import torch.nn as nn
import torch
from transformers.models.clip.modeling_clip import CLIPOutput
from transformers import CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput, CLIPVisionTransformer, clip_loss
from transformers import CLIPProcessor, CLIPConfig, CLIPVisionConfig
import os

from transformers.models.clip.modeling_flax_clip import CLIP_VISION_INPUTS_DOCSTRING, CLIP_INPUTS_DOCSTRING
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings

from flagai.model.base_model import BaseModel
from dataclasses import dataclass

from .modeling_berts import BertSeriesConfig, RobertaSeriesConfig, BertSeriesModelWithTransformation, RobertaSeriesModelWithTransformation

Expand Down
Loading