
Commit 8e0246b

[feat] Add gradio demo for inference & lora finetuning (#24)
* [deps] Add gradio & slugify
* [chore] Add gradio related rules
* [feat] Add gradio demo for inference & lora finetuning
* [deps] Change slugify version requirement
* [chore] Remove misleading annotations
* [chore] Reuse `get_resolutions`
1 parent 0e9762d commit 8e0246b

File tree

13 files changed: +1250 -1 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -263,3 +263,6 @@ tmp/
 
 webdoc/
 **/wandb/
+
+**/lora_checkpoints
+**/.gradio

gradio/configs/t2i.yaml

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Model Configuration
model:
  model_path: "THUDM/CogView4-6B"  # Path to the pre-trained model
  model_name: "cogview4-6b"  # Model name (options: "cogview4-6b")
  model_type: "t2i"  # Model type (text-to-image)
  training_type: "lora"  # Training type

# Output Configuration
output:
  output_dir: "/path/to/output"  # Directory to save outputs
  report_to: "tensorboard"  # Logging framework

# Data Configuration
data:
  data_root: "/path/to/data"  # Path to training data

# Training Configuration
training:
  seed: 42  # Random seed for reproducibility
  train_epochs: 1  # Number of training epochs
  batch_size: 1  # Batch size per GPU
  gradient_accumulation_steps: 1  # Number of gradient accumulation steps
  mixed_precision: "bf16"  # Mixed precision mode (options: "no", "fp16", "bf16")
  learning_rate: 2.0e-5  # Learning rate

  # Note: For CogView4 series models, height and width should be **32N** (multiple of 32)
  train_resolution: "1024x1024"  # Training resolution (height x width)

# System Configuration
system:
  num_workers: 8  # Number of dataloader workers
  pin_memory: true  # Whether to pin memory in the dataloader
  nccl_timeout: 1800  # NCCL timeout in seconds

# Checkpointing Configuration
checkpoint:
  checkpointing_steps: 10  # Save a checkpoint every x steps
  checkpointing_limit: 2  # Maximum number of checkpoints to keep

# Validation Configuration
validation:
  do_validation: true  # Whether to perform validation
  validation_steps: 10  # Validate every x steps (should be a multiple of checkpointing_steps)
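
The `train_resolution` string can be sanity-checked before launching a run. A minimal sketch (the `check_t2i_resolution` helper is hypothetical, not part of this commit), assuming the "heightxwidth" format and the multiple-of-32 rule noted in the config:

```python
def check_t2i_resolution(res: str) -> tuple[int, int]:
    """Parse a "heightxwidth" string such as "1024x1024" and enforce the
    CogView4 note that both dimensions must be multiples of 32."""
    height, width = (int(v) for v in res.split("x"))
    for name, value in (("height", height), ("width", width)):
        if value % 32 != 0:
            raise ValueError(f"{name}={value} is not a multiple of 32")
    return height, width

print(check_t2i_resolution("1024x1024"))  # (1024, 1024)
```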

gradio/configs/t2v.yaml

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Model Configuration
model:
  model_path: "THUDM/CogVideoX1.5-5B"  # Path to the pre-trained model
  model_name: "cogvideox1.5-t2v"  # Model name
  model_type: "t2v"  # Model type (text-to-video)
  training_type: "lora"  # Training type

# Output Configuration
output:
  output_dir: "/path/to/output"  # Directory to save outputs
  report_to: "tensorboard"  # Logging framework

# Data Configuration
data:
  data_root: "/path/to/data"  # Path to training data

# Training Configuration
training:
  seed: 42  # Random seed for reproducibility
  train_epochs: 1  # Number of training epochs
  batch_size: 1  # Batch size per GPU
  gradient_accumulation_steps: 1  # Number of gradient accumulation steps
  mixed_precision: "bf16"  # Mixed precision mode (options: "no", "fp16", "bf16")
  learning_rate: 2.0e-5  # Learning rate

  # Note: For CogVideoX series models, train_resolution is given as frames x height x width
  train_resolution: "81x768x1360"  # Training resolution (frames x height x width)

# System Configuration
system:
  num_workers: 8  # Number of dataloader workers
  pin_memory: true  # Whether to pin memory in the dataloader
  nccl_timeout: 1800  # NCCL timeout in seconds

# Checkpointing Configuration
checkpoint:
  checkpointing_steps: 10  # Save a checkpoint every x steps
  checkpointing_limit: 2  # Maximum number of checkpoints to keep

# Validation Configuration
validation:
  do_validation: true  # Whether to perform validation
  validation_steps: 10  # Validate every x steps (should be a multiple of checkpointing_steps)
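
For the video config, `train_resolution` carries three dimensions rather than two. A minimal parsing sketch (the `parse_t2v_resolution` helper is hypothetical; the 8N+1 frame-count rule comes from the CogVideoX finetuning docs, not this file):

```python
def parse_t2v_resolution(res: str) -> tuple[int, int, int]:
    """Split a "framesxheightxwidth" string such as "81x768x1360"."""
    frames, height, width = (int(v) for v in res.split("x"))
    # CogVideoX models expect a frame count of the form 8N+1 (e.g. 81).
    if frames % 8 != 1:
        raise ValueError(f"frames={frames} should be 8N+1")
    return frames, height, width

print(parse_t2v_resolution("81x768x1360"))  # (81, 768, 1360)
```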
