Skip to content

Commit 5ad85ab

Browse files
Merge pull request #6 from zRzRzRzRzRzRzR/dev-lhy
[refactor] Move training scripts out of src directory
2 parents 589950f + 1226163 commit 5ad85ab

File tree

19 files changed

+60
-62
lines changed

19 files changed

+60
-62
lines changed

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ api = [
3131
"pydantic-settings~=2.8",
3232
"python-dotenv~=1.0",
3333
]
34-
video = [
35-
"decord~=0.6.0",
36-
"opencv-python-headless~=4.11",
37-
]
3834

3935
# TODO: adds project urls
4036
# [project.urls]

src/cogkit/finetune/diffusion/accelerate_config.yaml renamed to quickstart/configs/accelerate_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ gpu_ids: "0,1,2,3,4,5,6,7"
44
num_processes: 8 # should be the same as the number of GPUs
55

66
# gpu_ids: "0"
7-
# num_processes: 1 # should be the same as the number of GPUs
7+
# num_processes: 1
88

99
debug: false
1010

1111
distributed_type: DEEPSPEED
1212
deepspeed_config:
13-
deepspeed_config_file: ../configs/zero/zero3.yaml # e.g. configs/zero2.yaml, need use absolute path
13+
deepspeed_config_file: /path/to/configs/zero/zero2.yaml # e.g. need use absolute path
1414
zero3_init_flag: false
1515

1616
downcast_bf16: 'no'
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
import argparse
2-
import sys
3-
from pathlib import Path
4-
5-
sys.path.append(str(Path(__file__).parent.parent))
62

73
from cogkit.finetune import get_model_cls
84

quickstart/scripts/train_ddp_t2i.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ CHECKPOINT_ARGS=(
5555
VALIDATION_ARGS=(
5656
--do_validation true # ["true", "false"]
5757
--validation_steps 10 # should be multiple of checkpointing_steps
58-
--gen_fps 16
5958
)
6059

6160
# Combine all arguments and launch training

quickstart/scripts/train_zero_i2v.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ VALIDATION_ARGS=(
6363
)
6464

6565
# Combine all arguments and launch training
66-
accelerate launch --config_file ./accelerate_config.yaml train.py \
66+
accelerate launch --config_file ../configs/accelerate_config.yaml train.py \
6767
"${MODEL_ARGS[@]}" \
6868
"${OUTPUT_ARGS[@]}" \
6969
"${DATA_ARGS[@]}" \

quickstart/scripts/train_zero_t2i.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ VALIDATION_ARGS=(
6262
)
6363

6464
# Combine all arguments and launch training
65-
accelerate launch --config_file ./accelerate_config.yaml train.py \
65+
accelerate launch --config_file ../configs/accelerate_config.yaml train.py \
6666
"${MODEL_ARGS[@]}" \
6767
"${OUTPUT_ARGS[@]}" \
6868
"${DATA_ARGS[@]}" \

quickstart/scripts/train_zero_t2v.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ VALIDATION_ARGS=(
6363
)
6464

6565
# Combine all arguments and launch training
66-
accelerate launch --config_file ./accelerate_config.yaml train.py \
66+
accelerate launch --config_file ../configs/accelerate_config.yaml train.py \
6767
"${MODEL_ARGS[@]}" \
6868
"${OUTPUT_ARGS[@]}" \
6969
"${DATA_ARGS[@]}" \

src/cogkit/datasets/i2v_dataset.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717

1818
from .utils import (
1919
get_prompt_embedding,
20-
load_images,
21-
load_images_from_videos,
22-
load_prompts,
23-
load_videos,
2420
preprocess_image_with_resize,
2521
preprocess_video_with_resize,
2622
)

src/cogkit/datasets/t2i_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66
import torch
77
from accelerate.logging import get_logger
88
from datasets import load_dataset
9-
from PIL import Image
109
from torch.utils.data import Dataset
11-
from torchvision import transforms
1210
from typing_extensions import override
1311

14-
from cogmodels.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
12+
from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
1513

1614
from .utils import (
1715
preprocess_image_with_resize,
@@ -20,7 +18,7 @@
2018
)
2119

2220
if TYPE_CHECKING:
23-
from cogmodels.finetune.diffusion.trainer import DiffusionTrainer
21+
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
2422

2523
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
2624
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.

src/cogkit/datasets/t2v_dataset.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import hashlib
21
from pathlib import Path
32
from typing import TYPE_CHECKING, Any
43

@@ -11,12 +10,12 @@
1110
from torchvision import transforms
1211
from typing_extensions import override
1312

14-
from cogmodels.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
13+
from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
1514

16-
from .utils import load_prompts, load_videos, preprocess_video_with_resize, get_prompt_embedding
15+
from .utils import get_prompt_embedding, preprocess_video_with_resize
1716

1817
if TYPE_CHECKING:
19-
from cogmodels.finetune.diffusion.trainer import DiffusionTrainer
18+
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
2019

2120
logger = get_logger(LOG_NAME, LOG_LEVEL)
2221

0 commit comments

Comments
 (0)