Skip to content

Commit 888b45f

Browse files
authored
[feat] Add QLoRA & multi-resolution packing support (#26)
* [feat] Implement multi-resolution packing and add QLoRA support - Implement multi-resolution packing for CogView4 to improve training efficiency - Add QLoRA support for both CogView and CogVideo models - Refactor trainers to fix training bugs and optimize the computation pipeline - Update dataset utilities and fine-tuning base components This update significantly improves model training efficiency and flexibility.
1 parent d041ff8 commit 888b45f

37 files changed

+1235
-705
lines changed

gradio/gradio_infer_demo.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def update_task(hf_model_id: str) -> Tuple[gr.Dropdown, gr.Component]:
119119
def update_subcheckpoints(checkpoint_dir):
120120
"""Get subdirectories for the selected checkpoint directory."""
121121
if checkpoint_dir == "None":
122-
return gr.Dropdown(choices=[], interactive=False, visible=False)
122+
return gr.Dropdown(choices=["None"], value="None", interactive=False, visible=False)
123123

124124
# Get the full path to the checkpoint directory
125125
full_checkpoint_path = os.path.join(checkpoint_rootdir, checkpoint_dir)
@@ -138,7 +138,7 @@ def update_subcheckpoints(checkpoint_dir):
138138

139139
if not subdirs:
140140
# If there are no subdirectories, hide the dropdown
141-
return gr.Dropdown(choices=[], interactive=False, visible=False)
141+
return gr.Dropdown(choices=["None"], value="None", interactive=False, visible=False)
142142

143143
# Show dropdown with available subdirectories
144144
return gr.Dropdown(
@@ -183,6 +183,7 @@ def load_model_and_generate(
183183
)
184184

185185
# Load LoRA weights if selected
186+
unload_lora_checkpoint(pipeline)
186187
if lora_checkpoint != "None":
187188
progress(0.3, desc="Loading LoRA weights...")
188189
# Construct the full path to the specific checkpoint
@@ -192,8 +193,6 @@ def load_model_and_generate(
192193
lora_path = lora_checkpoint
193194
logger.info(f"Loading LoRA weights from {lora_path}")
194195
load_lora_checkpoint(pipeline, lora_path)
195-
else:
196-
unload_lora_checkpoint(pipeline)
197196

198197
# Generate content based on task
199198
progress(0.5, desc="Generating content...")
@@ -300,7 +299,7 @@ def load_model_and_generate(
300299
guidance_scale = gr.Slider(
301300
minimum=1.0,
302301
maximum=15.0,
303-
value=6.0,
302+
value=5.0,
304303
step=0.1,
305304
label="Guidance Scale",
306305
info="Higher values increase prompt adherence",

gradio/gradio_lora_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from torchvision.io import write_video
1212
from utils import (
1313
BaseTask,
14-
flatten_dict,
1514
get_dataset_dirs,
1615
get_logger,
1716
get_lora_checkpoint_rootdir,
@@ -24,6 +23,7 @@
2423

2524
import gradio as gr
2625
from cogkit import GenerationMode, guess_generation_mode
26+
from cogkit.utils import flatten_dict
2727

2828
# ======================= global state ====================
2929

gradio/utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
resolve_path,
99
)
1010
from .logging import get_logger
11-
from .misc import flatten_dict, get_resolutions
11+
from .misc import get_resolutions
1212
from .task import BaseTask
1313

1414
__all__ = [
@@ -22,5 +22,4 @@
2222
"resolve_path",
2323
"BaseTask",
2424
"get_resolutions",
25-
"flatten_dict",
2625
]

gradio/utils/misc.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
from typing import Any, Dict, List
2-
31
from cogkit import GenerationMode
42

53

6-
def get_resolutions(task: GenerationMode) -> List[str]:
4+
def get_resolutions(task: GenerationMode) -> list[str]:
75
if task == GenerationMode.TextToImage:
86
return [
97
"512x512",
@@ -19,35 +17,3 @@ def get_resolutions(task: GenerationMode) -> List[str]:
1917
"49x480x720",
2018
"81x768x1360",
2119
]
22-
23-
24-
def flatten_dict(d: Dict[str, Any], ignore_none: bool = False) -> Dict[str, Any]:
25-
"""
26-
Flattens a nested dictionary into a single layer dictionary.
27-
28-
Args:
29-
d: The dictionary to flatten
30-
ignore_none: If True, keys with None values will be omitted
31-
32-
Returns:
33-
A flattened dictionary
34-
35-
Raises:
36-
ValueError: If there are duplicate keys across nested dictionaries
37-
"""
38-
result = {}
39-
40-
def _flatten(current_dict, result_dict):
41-
for key, value in current_dict.items():
42-
if value is None and ignore_none:
43-
continue
44-
45-
if isinstance(value, dict):
46-
_flatten(value, result_dict)
47-
else:
48-
if key in result_dict:
49-
raise ValueError(f"Duplicate key '{key}' found in nested dictionary")
50-
result_dict[key] = value
51-
52-
_flatten(d, result)
53-
return result

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ finetune = [
3333
"datasets~=3.4",
3434
"deepspeed~=0.16.4",
3535
"av~=14.2.0",
36+
"bitsandbytes~=0.45.4",
37+
"tensorboard~=2.19",
3638
]
3739

3840
[project.urls]

quickstart/scripts/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def main():
77
parser = argparse.ArgumentParser()
88
parser.add_argument("--model_name", type=str, required=True)
99
parser.add_argument("--training_type", type=str, required=True)
10-
parser.add_argument("--enable_packing", action="store_true")
10+
parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true")
1111
args, unknown = parser.parse_known_args()
1212

1313
trainer_cls = get_model_cls(args.model_name, args.training_type, args.enable_packing)

quickstart/scripts/train_ddp_i2v.sh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ OUTPUT_ARGS=(
2121
# Data Configuration
2222
DATA_ARGS=(
2323
--data_root "/path/to/data"
24-
25-
# Note:
26-
# for CogVideoX series models, number of training frames should be **8N+1**
27-
# for CogVideoX1.5 series models, number of training frames should be **16N+1**
28-
--train_resolution "81x768x1360" # (frames x height x width)
2924
)
3025

3126
# Training Configuration
@@ -35,13 +30,18 @@ TRAIN_ARGS=(
3530
--batch_size 1
3631
--gradient_accumulation_steps 1
3732
--mixed_precision "bf16" # ["no", "fp16"]
38-
--learning_rate 2e-5
33+
--learning_rate 5e-5
34+
35+
# Note:
36+
# for CogVideoX series models, number of training frames should be **8N+1**
37+
# for CogVideoX1.5 series models, number of training frames should be **16N+1**
38+
--train_resolution "81x768x1360" # (frames x height x width)
3939
)
4040

4141
# System Configuration
4242
SYSTEM_ARGS=(
4343
--num_workers 8
44-
--pin_memory True
44+
--pin_memory true
4545
--nccl_timeout 1800
4646
)
4747

quickstart/scripts/train_ddp_t2i.sh

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,38 @@ OUTPUT_ARGS=(
2121
# Data Configuration
2222
DATA_ARGS=(
2323
--data_root "/path/to/data"
24-
25-
# Note:
26-
# For CogView4 series models, height and width should be **32N** (multiple of 32)
27-
--train_resolution "1024x1024" # (height x width)
2824
)
2925

3026
# Training Configuration
3127
TRAIN_ARGS=(
3228
--seed 42 # random seed
3329
--train_epochs 1 # number of training epochs
3430
--batch_size 1
31+
3532
--gradient_accumulation_steps 1
33+
34+
# Note: For CogView4 series models, height and width should be **32N** (multiple of 32)
35+
--train_resolution "1024x1024" # (height x width)
36+
37+
# When enable_packing is true, training will use the native image resolution
38+
# (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio).
39+
#
40+
# IMPORTANT: When changing enable_packing from true to false (or vice versa),
41+
# make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist.
42+
--enable_packing false
43+
3644
--mixed_precision "bf16" # ["no", "fp16"]
37-
--learning_rate 2e-5
45+
--learning_rate 5e-5
46+
47+
# enable --low_vram will slow down validation speed and enable quantization during training
48+
# Note: --low_vram currently does not support multi-GPU training
49+
--low_vram false
3850
)
3951

4052
# System Configuration
4153
SYSTEM_ARGS=(
4254
--num_workers 8
43-
--pin_memory True
55+
--pin_memory true
4456
--nccl_timeout 1800
4557
)
4658

quickstart/scripts/train_ddp_t2v.sh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ OUTPUT_ARGS=(
2020
# Data Configuration
2121
DATA_ARGS=(
2222
--data_root "/path/to/data"
23-
24-
# Note:
25-
# for CogVideoX series models, number of training frames should be **8N+1**
26-
# for CogVideoX1.5 series models, number of training frames should be **16N+1**
27-
--train_resolution "81x768x1360" # (frames x height x width)
2823
)
2924

3025
# Training Configuration
@@ -34,13 +29,18 @@ TRAIN_ARGS=(
3429
--batch_size 1
3530
--gradient_accumulation_steps 1
3631
--mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training
37-
--learning_rate 2e-5
32+
--learning_rate 5e-5
33+
34+
# Note:
35+
# for CogVideoX series models, number of training frames should be **8N+1**
36+
# for CogVideoX1.5 series models, number of training frames should be **16N+1**
37+
--train_resolution "81x768x1360" # (frames x height x width)
3838
)
3939

4040
# System Configuration
4141
SYSTEM_ARGS=(
4242
--num_workers 8
43-
--pin_memory True
43+
--pin_memory true
4444
--nccl_timeout 1800
4545
)
4646

quickstart/scripts/train_zero_i2v.sh

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,32 @@ OUTPUT_ARGS=(
2020
# Data Configuration
2121
DATA_ARGS=(
2222
--data_root "/path/to/data"
23-
24-
# Note:
25-
# for CogVideoX series models, number of training frames should be **8N+1**
26-
# for CogVideoX1.5 series models, number of training frames should be **16N+1**
27-
--train_resolution "81x768x1360" # (frames x height x width)
2823
)
2924

3025
# Training Configuration
3126
TRAIN_ARGS=(
3227
--seed 42 # random seed
3328
--train_epochs 1 # number of training epochs
3429

35-
--learning_rate 2e-5
30+
--learning_rate 5e-5
3631

3732
######### Please keep consistent with deepspeed config file ##########
3833
--batch_size 1
3934
--gradient_accumulation_steps 1
4035
--mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training
4136
########################################################################
37+
38+
# Note:
39+
# for CogVideoX series models, number of training frames should be **8N+1**
40+
# for CogVideoX1.5 series models, number of training frames should be **16N+1**
41+
--train_resolution "81x768x1360" # (frames x height x width)
42+
4243
)
4344

4445
# System Configuration
4546
SYSTEM_ARGS=(
4647
--num_workers 8
47-
--pin_memory True
48+
--pin_memory true
4849
--nccl_timeout 1800
4950
)
5051

0 commit comments

Comments
 (0)