Commit 22fc5a7

Author: Vincent Dupont (committed)
chore: fix linting
1 parent be459a7

File tree: 149 files changed (+454, -675 lines)


.gitignore (+1, -1)

@@ -24,6 +24,6 @@ SwissArmyTransformer/
 
 trainingdata
 temp
-
+.cursor
 
 *.outputs

.pre-commit-config.yaml (+6, -6)

@@ -10,12 +10,12 @@ repos:
         pass_filenames: false
         language: system
         stages: [pre-commit]
-      # - id: linting
-      # name: linting
-      # entry: poetry run lint
-      # pass_filenames: false
-      # language: system
-      # stages: [commit]
+      - id: linting
+        name: linting
+        entry: poetry run lint
+        pass_filenames: false
+        language: system
+        stages: [commit]
       # - id: type-checking
       # name: type checking
       # entry: poetry run type-check
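
Note: the substantive change above is that the previously commented-out linting hook is re-enabled, so pre-commit now runs "poetry run lint" as a local hook at the commit stage (the type-checking hook stays disabled). As a loose illustration only, and not code from this repository, the script behind a Poetry entry point named "lint" could be as small as:

# Hypothetical sketch of a "lint" script that `poetry run lint` might resolve to;
# the real script is defined elsewhere in the repo and is not part of this diff.
import subprocess
import sys


def lint() -> None:
    # Run ruff over the project and propagate its exit code so that a
    # non-zero status makes the pre-commit hook fail.
    result = subprocess.run(["ruff", "check", "."])
    sys.exit(result.returncode)


if __name__ == "__main__":
    lint()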

pyproject.toml (+2, -2)

@@ -125,6 +125,6 @@ module = [
 ]
 ignore_missing_imports = true
 
-[tool.ruff]
+[tool.ruff.lint]
 select = ["E", "F", "C90"]
-ignore = []
+ignore = ["E501", "C901"]

scripts/inference_flux_lora.py (+10, -2)

@@ -3,6 +3,7 @@
 
 import torch
 from diffusers import FluxPipeline
+
 from videotuna.utils.inference_utils import load_prompts_from_txt
 
 
@@ -56,9 +57,16 @@ def inference(args):
     parser.add_argument(
         "--model_type", type=str, default="dev", choices=["dev", "schnell"]
     )
-    parser.add_argument("--prompt", type=str, default="A photo of a cat", help="Inference prompt, string or path to a .txt file")
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="A photo of a cat",
+        help="Inference prompt, string or path to a .txt file",
+    )
     parser.add_argument("--out_path", type=str, default="./results/t2i/image.png")
-    parser.add_argument("--lora_path", type=str, default=None, help="Full path to lora weights")
+    parser.add_argument(
+        "--lora_path", type=str, default=None, help="Full path to lora weights"
+    )
     parser.add_argument("--width", type=int, default=1360)
     parser.add_argument("--height", type=int, default=768)
     parser.add_argument("--num_inference_steps", type=int, default=4)

scripts/train_flux_lora.py (+5)

@@ -23,6 +23,7 @@
 logger = logging.getLogger("SimpleTuner")
 logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
 
+
 def add_timestamp_to_output_dir(output_dir):
     time_str = time.strftime("%Y%m%d%H%M%S")
     folder_name = output_dir.stem
@@ -33,6 +34,7 @@ def add_timestamp_to_output_dir(output_dir):
     output_dir = output_dir.parent / folder_name
     return str(output_dir)
 
+
 def config_process(config):
     # add timestamp to the output_dir
     output_dir = Path(config["--output_dir"])
@@ -42,6 +44,7 @@ def config_process(config):
         json.dump(config, f, indent=4)
     return config
 
+
 def load_yaml_config(config_path):
     with open(config_path) as f:
         config = yaml.safe_load(f)
@@ -58,6 +61,7 @@ def load_yaml_config(config_path):
 
     return config, data_config_json
 
+
 def load_json_config(config_path, data_config_path):
     # load config files
     with open(config_path) as f:
@@ -68,6 +72,7 @@ def load_json_config(config_path, data_config_path):
     config = config_process(config)
     return config, data_config
 
+
 def main(args):
     try:
         import multiprocessing
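
Note: the only additions in this file are blank lines. Pycodestyle's E302 convention wants two blank lines before each top-level def, so one extra blank line is inserted ahead of every function in the script. For illustration only (not repository code):

import time


# Two blank lines separate top-level definitions, which satisfies E302.
def timestamp() -> str:
    return time.strftime("%Y%m%d%H%M%S")


def main() -> None:
    print(timestamp())


if __name__ == "__main__":
    main()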

tests/datasets/test_dataset_from_csv.py (+2, -1)

@@ -240,7 +240,8 @@ def test_video_dataset_from_csv_with_split(self):
         print(f"len(dataset): {len(val_dataset)}")
         self.assertLessEqual(len(val_dataset), 128)
         self.assertEqual(val_dataset[0]["video"].shape[2], 256)
-        # Check if the sum of the lengths of the training and validation datasets is equal to the total number of samples
+        # Check if the sum of the lengths of the training and validation datasets
+        # is equal to the total number of samples
         self.assertEqual(len(train_dataset) + len(val_dataset), 128)
 
 
videotuna/base/ddim.py (+10, -6)

@@ -19,7 +19,7 @@ def __init__(self, model, schedule="linear", **kwargs):
         self.counter = 0
 
     def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
+        if isinstance(attr, torch.Tensor):
             if attr.device != torch.device("cuda"):
                 attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
@@ -37,7 +37,9 @@ def make_schedule(
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        def to_torch(x):
+            return x.clone().detach().to(torch.float32).to(self.model.device)
 
         self.register_buffer("betas", to_torch(self.model.diffusion_scheduler.betas))
         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
@@ -133,22 +135,24 @@ def sample(
        if isinstance(conditioning, dict):
            try:
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-           except:
+           except Exception:
                try:
                    cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
-               except:
+               except Exception:
                    cbs = int(
                        conditioning[list(conditioning.keys())[0]][0]["y"].shape[0]
                    )
 
            if cbs != batch_size:
                print(
-                   f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                   f"Warning: Got {cbs} conditionings but "
+                   f"batch-size is {batch_size}"
                )
        else:
            if conditioning.shape[0] != batch_size:
                print(
-                   f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                   f"Warning: Got {conditioning.shape[0]} conditionings but "
+                   f"batch-size is {batch_size}"
                )
 
        self.make_schedule(
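
Note: the edits in this file, and in the sampler/model files below, are classic lint-driven rewrites: isinstance() instead of comparing type() results (E721), a named def instead of a lambda assigned to a variable (E731), "except Exception:" instead of a bare "except:" (E722), and long f-strings split across adjacent literals. A small before/after sketch of these patterns, not taken from the repository:

import torch

# Patterns the linter flags:
#   if type(attr) == torch.Tensor:   # E721: use isinstance() for type checks
#   to_torch = lambda x: x.float()   # E731: do not assign a lambda, use def
#   except:                          # E722: bare except also traps SystemExit


def to_torch(x: torch.Tensor) -> torch.Tensor:
    # Named equivalent of the lambda; the repository version also moves to the model device.
    return x.clone().detach().to(torch.float32)


def coerce_buffer(obj: object, name: str, attr: object) -> None:
    # isinstance() also accepts subclasses, unlike an exact type() comparison.
    if isinstance(attr, torch.Tensor):
        attr = attr.to(torch.float32)
    setattr(obj, name, attr)


try:
    batch_size = int("not a number")
except Exception:
    # Catches ordinary errors while still letting KeyboardInterrupt/SystemExit escape.
    batch_size = 1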

videotuna/base/ddim_multiplecond.py (+7, -6)

@@ -1,5 +1,3 @@
-import copy
-
 import numpy as np
 import torch
 from tqdm import tqdm
@@ -21,7 +19,7 @@ def __init__(self, model, schedule="linear", **kwargs):
         self.counter = 0
 
     def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
+        if isinstance(attr, torch.Tensor):
             if attr.device != torch.device("cuda"):
                 attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
@@ -39,7 +37,9 @@ def make_schedule(
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        def to_torch(x):
+            return x.clone().detach().to(torch.float32).to(self.model.device)
 
         if self.model.use_scale:
             self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
@@ -120,7 +120,8 @@ def sample(
        fs=None,
        timestep_spacing="uniform",  # uniform_trailing for starting from last timestep
        guidance_rescale=0.0,
-       # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+       # this has to come in the same format as the conditioning,
+       # e.g. as encoded tokens, ...
        **kwargs,
    ):
 
@@ -129,7 +130,7 @@ def sample(
        if isinstance(conditioning, dict):
            try:
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-           except:
+           except Exception:
                cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
 
            if cbs != batch_size:

videotuna/base/ddpm3d.py (+10, -19)

@@ -7,46 +7,37 @@
 """
 
 import logging
-import os
 import random
 from contextlib import contextmanager
 from functools import partial
 
 import numpy as np
-from einops import rearrange, repeat
-from tqdm import tqdm
-
-mainlogger = logging.getLogger("mainlogger")
-
 import peft
 import pytorch_lightning as pl
 import torch
-import torch.nn as nn
+from einops import rearrange, repeat
 from pytorch_lightning.utilities import rank_zero_only
 from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
 from torchvision.utils import make_grid
+from tqdm import tqdm
 
 from videotuna.base.ddim import DDIMSampler
-from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl
+from videotuna.base.distributions import DiagonalGaussianDistribution
 from videotuna.base.ema import LitEma
-from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
 
 # import rlhf utils
 from videotuna.lvdm.models.rlhf_utils.batch_ddim import batch_ddim_sampling
 from videotuna.lvdm.models.rlhf_utils.reward_fn import aesthetic_loss_fn
 from videotuna.lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler
-from videotuna.lvdm.modules.utils import (
-    default,
-    disabled_train,
-    exists,
-    extract_into_tensor,
-    noise_like,
-)
+from videotuna.lvdm.modules.utils import default, disabled_train, extract_into_tensor
 from videotuna.utils.common_utils import instantiate_from_config
 
 __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
 
 
+mainlogger = logging.getLogger("mainlogger")
+
+
 class DDPMFlow(pl.LightningModule):
     # classic DDPM with Gaussian diffusion, in image space
     def __init__(
@@ -430,7 +421,7 @@ def load_lora_from_ckpt(self, model, path):
                     f"Parameter {key} from lora_state_dict was not copied to the model."
                 )
                 # print(f"Parameter {key} from lora_state_dict was not copied to the model.")
-        print(f"All Parameters was copied successfully.")
+        print("All Parameters was copied successfully.")
 
     def inject_lora(self):
         """inject lora into the denoising module.
@@ -519,7 +510,7 @@ def __init__(
 
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -1586,7 +1577,7 @@ def configure_optimizers(self):
 
        if self.cond_stage_trainable:
            params_cond_stage = [
-               p for p in self.cond_stage_model.parameters() if p.requires_grad == True
+               p for p in self.cond_stage_model.parameters() if p.requires_grad is True
            ]
            mainlogger.info(
                f"@Training [{len(params_cond_stage)}] Paramters for Cond_stage_model."

videotuna/base/diffusion_schedulers.py (-1)

@@ -8,7 +8,6 @@
 from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
 from videotuna.lvdm.modules.utils import (
     default,
-    disabled_train,
     exists,
     extract_into_tensor,
     noise_like,

videotuna/base/ema.py (+2, -2)

@@ -49,7 +49,7 @@ def forward(self, model):
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
-                   assert not key in self.m_name2s_name
+                   assert key not in self.m_name2s_name
 
    def copy_to(self, model):
        m_param = dict(model.named_parameters())
@@ -58,7 +58,7 @@ def copy_to(self, model):
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
-               assert not key in self.m_name2s_name
+               assert key not in self.m_name2s_name
 
    def store(self, parameters):
        """

videotuna/base/iddpm3d.py (+8, -16)

@@ -1,33 +1,23 @@
 import enum
 import logging
 import math
-import os
 import random
-from contextlib import contextmanager
 from functools import partial
 
 import numpy as np
-from einops import rearrange, repeat
-from omegaconf.listconfig import ListConfig
-from tqdm import tqdm
-
-mainlogger = logging.getLogger("mainlogger")
-
 import torch
-import torch.nn as nn
+from einops import rearrange
+from omegaconf.listconfig import ListConfig
 from pytorch_lightning.utilities import rank_zero_only
 from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
 from torchvision.utils import make_grid
+from tqdm import tqdm
 
 from videotuna.base.ddim import DDIMSampler
 from videotuna.base.ddpm3d import DDPMFlow
 from videotuna.base.diffusion_schedulers import DDPMScheduler
 from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl
-from videotuna.base.utils_diffusion import (
-    discretized_gaussian_log_likelihood,
-    make_beta_schedule,
-    rescale_zero_terminal_snr,
-)
+from videotuna.base.utils_diffusion import discretized_gaussian_log_likelihood
 from videotuna.lvdm.modules.utils import (
     default,
     disabled_train,
@@ -37,6 +27,8 @@
 )
 from videotuna.utils.common_utils import instantiate_from_config
 
+mainlogger = logging.getLogger("mainlogger")
+
 
 def mean_flat(tensor: torch.Tensor, mask=None) -> torch.Tensor:
     """
@@ -1039,7 +1031,7 @@ def __init__(
 
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -1305,7 +1297,7 @@ def apply_model(self, x_noisy, t, cond, **kwargs):
                        key: [cond["c_crossattn"][0]["y"]],
                        "mask": [cond["c_crossattn"][0]["mask"]],
                    }
-               except:
+               except Exception:
                    cond = {key: [cond["y"]], "mask": [cond["mask"]]}  # support mask for T5
            else:
                if isinstance(cond, dict):

videotuna/base/utils_diffusion.py (-1)

@@ -2,7 +2,6 @@
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 from einops import repeat
 
 
