
chore: fix linting #35

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion .gitignore
@@ -24,6 +24,6 @@ SwissArmyTransformer/

trainingdata
temp

.cursor

*.outputs
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -10,12 +10,12 @@ repos:
pass_filenames: false
language: system
stages: [pre-commit]
# - id: linting
# name: linting
# entry: poetry run lint
# pass_filenames: false
# language: system
# stages: [commit]
- id: linting
name: linting
entry: poetry run lint
pass_filenames: false
language: system
stages: [commit]
# - id: type-checking
# name: type checking
# entry: poetry run type-check
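Note on the re-enabled hook: `entry: poetry run lint` implies a console-script entry point registered under `[tool.poetry.scripts]`. A minimal sketch of what such an entry point could look like, assuming it simply shells out to ruff; the actual `lint` command in this repository may differ:

```python
# Hypothetical lint entry point; assumes a pyproject.toml mapping such as
# lint = "tasks:lint" under [tool.poetry.scripts]. Illustrative only.
import subprocess
import sys


def lint() -> None:
    # Run ruff over the repository and propagate its exit code so the
    # pre-commit hook fails when violations are found.
    result = subprocess.run(["ruff", "check", "."])
    sys.exit(result.returncode)


if __name__ == "__main__":
    lint()
```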
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -125,6 +125,6 @@ module = [
]
ignore_missing_imports = true

[tool.ruff]
[tool.ruff.lint]
select = ["E", "F", "C90"]
ignore = []
ignore = ["E501", "C901"]
12 changes: 10 additions & 2 deletions scripts/inference_flux_lora.py
@@ -3,6 +3,7 @@

import torch
from diffusers import FluxPipeline

from videotuna.utils.inference_utils import load_prompts_from_txt


@@ -56,9 +57,16 @@ def inference(args):
parser.add_argument(
"--model_type", type=str, default="dev", choices=["dev", "schnell"]
)
parser.add_argument("--prompt", type=str, default="A photo of a cat", help="Inference prompt, string or path to a .txt file")
parser.add_argument(
"--prompt",
type=str,
default="A photo of a cat",
help="Inference prompt, string or path to a .txt file",
)
parser.add_argument("--out_path", type=str, default="./results/t2i/image.png")
parser.add_argument("--lora_path", type=str, default=None, help="Full path to lora weights")
parser.add_argument(
"--lora_path", type=str, default=None, help="Full path to lora weights"
)
parser.add_argument("--width", type=int, default=1360)
parser.add_argument("--height", type=int, default=768)
parser.add_argument("--num_inference_steps", type=int, default=4)
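The `--prompt` help text says the argument is either a literal prompt or a path to a `.txt` file. A hedged sketch of how that could be resolved, assuming `load_prompts_from_txt(path)` returns one prompt string per line; check `videotuna/utils/inference_utils.py` for the real signature:

```python
# Sketch only; the helper's actual signature and behaviour may differ.
from pathlib import Path

from videotuna.utils.inference_utils import load_prompts_from_txt


def resolve_prompts(prompt_arg: str) -> list:
    if prompt_arg.endswith(".txt") and Path(prompt_arg).is_file():
        return load_prompts_from_txt(prompt_arg)  # assumed: list of strings
    return [prompt_arg]
```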
5 changes: 5 additions & 0 deletions scripts/train_flux_lora.py
@@ -23,6 +23,7 @@
logger = logging.getLogger("SimpleTuner")
logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))


def add_timestamp_to_output_dir(output_dir):
time_str = time.strftime("%Y%m%d%H%M%S")
folder_name = output_dir.stem
@@ -33,6 +34,7 @@ def add_timestamp_to_output_dir(output_dir):
output_dir = output_dir.parent / folder_name
return str(output_dir)


def config_process(config):
# add timestamp to the output_dir
output_dir = Path(config["--output_dir"])
@@ -42,6 +44,7 @@ def config_process(config):
json.dump(config, f, indent=4)
return config


def load_yaml_config(config_path):
with open(config_path) as f:
config = yaml.safe_load(f)
@@ -58,6 +61,7 @@ def load_yaml_config(config_path):

return config, data_config_json


def load_json_config(config_path, data_config_path):
# load config files
with open(config_path) as f:
@@ -68,6 +72,7 @@ def load_json_config(config_path, data_config_path):
config = config_process(config)
return config, data_config


def main(args):
try:
import multiprocessing
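The hunks in this file only add the second blank line that pycodestyle's E302 expects between top-level definitions. A minimal illustration with hypothetical functions:

```python
# E302: expected 2 blank lines between top-level definitions, got 1.
def load_config(path: str) -> dict:
    return {"path": path}


def main() -> None:  # exactly two blank lines above this def satisfies E302
    print(load_config("configs/train.yaml"))
```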
3 changes: 2 additions & 1 deletion tests/datasets/test_dataset_from_csv.py
@@ -240,7 +240,8 @@ def test_video_dataset_from_csv_with_split(self):
print(f"len(dataset): {len(val_dataset)}")
self.assertLessEqual(len(val_dataset), 128)
self.assertEqual(val_dataset[0]["video"].shape[2], 256)
# Check if the sum of the lengths of the training and validation datasets is equal to the total number of samples
# Check if the sum of the lengths of the training and validation datasets
# is equal to the total number of samples
self.assertEqual(len(train_dataset) + len(val_dataset), 128)


16 changes: 10 additions & 6 deletions videotuna/base/ddim.py
@@ -19,7 +19,7 @@ def __init__(self, model, schedule="linear", **kwargs):
self.counter = 0

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if isinstance(attr, torch.Tensor):
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@@ -37,7 +37,9 @@ def make_schedule(
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)

self.register_buffer("betas", to_torch(self.model.diffusion_scheduler.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
@@ -133,22 +135,24 @@ def sample(
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
except:
except Exception:
try:
cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
except:
except Exception:
cbs = int(
conditioning[list(conditioning.keys())[0]][0]["y"].shape[0]
)

if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
f"Warning: Got {cbs} conditionings but "
f"batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
f"Warning: Got {conditioning.shape[0]} conditionings but "
f"batch-size is {batch_size}"
)

self.make_schedule(
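Two of the changes above map to common ruff rules: E721 discourages comparing types with `==` in favour of `isinstance`, and E731 discourages binding a lambda to a name in favour of a `def`. A short illustration with hypothetical values:

```python
import torch

# E721: isinstance() also accepts subclasses such as torch.nn.Parameter,
# whereas `type(x) == torch.Tensor` matches only the exact class.
x = torch.nn.Parameter(torch.zeros(3))
print(type(x) == torch.Tensor)      # False
print(isinstance(x, torch.Tensor))  # True


# E731: instead of `to_torch = lambda v: ...`, use a def, which gets a real
# name in tracebacks and can be documented.
def to_torch(v: torch.Tensor) -> torch.Tensor:
    return v.clone().detach().to(torch.float32)


print(to_torch(x).dtype)  # torch.float32
```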
13 changes: 7 additions & 6 deletions videotuna/base/ddim_multiplecond.py
@@ -1,5 +1,3 @@
import copy

import numpy as np
import torch
from tqdm import tqdm
@@ -21,7 +19,7 @@ def __init__(self, model, schedule="linear", **kwargs):
self.counter = 0

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if isinstance(attr, torch.Tensor):
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@@ -39,7 +37,9 @@ def make_schedule(
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)

if self.model.use_scale:
self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
@@ -120,7 +120,8 @@ def sample(
fs=None,
timestep_spacing="uniform", # uniform_trailing for starting from last timestep
guidance_rescale=0.0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
# this has to come in the same format as the conditioning,
# e.g. as encoded tokens, ...
**kwargs,
):

@@ -129,7 +130,7 @@
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
except:
except Exception:
cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]

if cbs != batch_size:
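The bare `except:` → `except Exception:` change (E722) is more than cosmetic: a bare except also swallows `KeyboardInterrupt` and `SystemExit`. A minimal sketch with a hypothetical conditioning dict shaped like the ones handled in `sample()`:

```python
import torch

# Hypothetical conditioning dict; the value is a list of tensors.
conditioning = {"c_crossattn": [torch.zeros(2, 77, 768)]}

try:
    # A list has no .shape, so this raises AttributeError ...
    cbs = conditioning[list(conditioning.keys())[0]].shape[0]
except Exception:
    # ... which `except Exception` still catches, while no longer swallowing
    # KeyboardInterrupt or SystemExit the way a bare `except:` would.
    cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]

print(cbs)  # 2
```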
29 changes: 10 additions & 19 deletions videotuna/base/ddpm3d.py
@@ -7,46 +7,37 @@
"""

import logging
import os
import random
from contextlib import contextmanager
from functools import partial

import numpy as np
from einops import rearrange, repeat
from tqdm import tqdm

mainlogger = logging.getLogger("mainlogger")

import peft
import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange, repeat
from pytorch_lightning.utilities import rank_zero_only
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torchvision.utils import make_grid
from tqdm import tqdm

from videotuna.base.ddim import DDIMSampler
from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl
from videotuna.base.distributions import DiagonalGaussianDistribution
from videotuna.base.ema import LitEma
from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr

# import rlhf utils
from videotuna.lvdm.models.rlhf_utils.batch_ddim import batch_ddim_sampling
from videotuna.lvdm.models.rlhf_utils.reward_fn import aesthetic_loss_fn
from videotuna.lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler
from videotuna.lvdm.modules.utils import (
default,
disabled_train,
exists,
extract_into_tensor,
noise_like,
)
from videotuna.lvdm.modules.utils import default, disabled_train, extract_into_tensor
from videotuna.utils.common_utils import instantiate_from_config

__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}


mainlogger = logging.getLogger("mainlogger")


class DDPMFlow(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
@@ -430,7 +421,7 @@ def load_lora_from_ckpt(self, model, path):
f"Parameter {key} from lora_state_dict was not copied to the model."
)
# print(f"Parameter {key} from lora_state_dict was not copied to the model.")
print(f"All Parameters was copied successfully.")
print("All Parameters was copied successfully.")

def inject_lora(self):
"""inject lora into the denoising module.
@@ -519,7 +510,7 @@ def __init__(

try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -1586,7 +1577,7 @@ def configure_optimizers(self):

if self.cond_stage_trainable:
params_cond_stage = [
p for p in self.cond_stage_model.parameters() if p.requires_grad == True
p for p in self.cond_stage_model.parameters() if p.requires_grad is True
]
mainlogger.info(
f"@Training [{len(params_cond_stage)}] Paramters for Cond_stage_model."
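Two smaller fixes in this file also correspond to specific rule codes: F541 flags f-strings with no placeholders (the `print(f"...")` line above), and E712 flags comparisons such as `requires_grad == True`. A brief hypothetical illustration:

```python
import torch

# F541: an f-string without placeholders is just a plain string.
print("All parameters copied successfully.")  # rather than print(f"...")

# E712: compare against True with `is`, or rely on truthiness directly.
params = [
    torch.nn.Parameter(torch.zeros(1), requires_grad=flag) for flag in (True, False)
]
trainable = [p for p in params if p.requires_grad]  # plain truthiness also works
print(len(trainable))  # 1
```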
1 change: 0 additions & 1 deletion videotuna/base/diffusion_schedulers.py
@@ -8,7 +8,6 @@
from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
from videotuna.lvdm.modules.utils import (
default,
disabled_train,
exists,
extract_into_tensor,
noise_like,
4 changes: 2 additions & 2 deletions videotuna/base/ema.py
@@ -49,7 +49,7 @@ def forward(self, model):
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name

def copy_to(self, model):
m_param = dict(model.named_parameters())
@@ -58,7 +58,7 @@ def copy_to(self, model):
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name

def store(self, parameters):
"""
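The `assert not key in ...` → `assert key not in ...` rewrite is E713 (membership tests should use `not in`); the two forms are semantically identical, but the latter reads as a single operator. A tiny illustration with a hypothetical name map:

```python
# E713: prefer `key not in mapping` over `not key in mapping`.
m_name2s_name = {"decoder.weight": "decoder_weight"}  # hypothetical mapping

key = "encoder.weight"
assert key not in m_name2s_name
```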
24 changes: 8 additions & 16 deletions videotuna/base/iddpm3d.py
@@ -1,33 +1,23 @@
import enum
import logging
import math
import os
import random
from contextlib import contextmanager
from functools import partial

import numpy as np
from einops import rearrange, repeat
from omegaconf.listconfig import ListConfig
from tqdm import tqdm

mainlogger = logging.getLogger("mainlogger")

import torch
import torch.nn as nn
from einops import rearrange
from omegaconf.listconfig import ListConfig
from pytorch_lightning.utilities import rank_zero_only
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torchvision.utils import make_grid
from tqdm import tqdm

from videotuna.base.ddim import DDIMSampler
from videotuna.base.ddpm3d import DDPMFlow
from videotuna.base.diffusion_schedulers import DDPMScheduler
from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl
from videotuna.base.utils_diffusion import (
discretized_gaussian_log_likelihood,
make_beta_schedule,
rescale_zero_terminal_snr,
)
from videotuna.base.utils_diffusion import discretized_gaussian_log_likelihood
from videotuna.lvdm.modules.utils import (
default,
disabled_train,
@@ -37,6 +27,8 @@
)
from videotuna.utils.common_utils import instantiate_from_config

mainlogger = logging.getLogger("mainlogger")


def mean_flat(tensor: torch.Tensor, mask=None) -> torch.Tensor:
"""
@@ -1039,7 +1031,7 @@ def __init__(

try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -1305,7 +1297,7 @@ def apply_model(self, x_noisy, t, cond, **kwargs):
key: [cond["c_crossattn"][0]["y"]],
"mask": [cond["c_crossattn"][0]["mask"]],
}
except:
except Exception:
cond = {key: [cond["y"]], "mask": [cond["mask"]]} # support mask for T5
else:
if isinstance(cond, dict):
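Most of the churn at the top of this file (and of `ddpm3d.py` above) is F401 unused-import cleanup plus moving `mainlogger = logging.getLogger("mainlogger")` below the import block, since a statement between imports makes every later import an E402 violation. A compact before/after sketch with hypothetical imports:

```python
# Before (schematic):
#
#   import logging
#   mainlogger = logging.getLogger("mainlogger")  # statement between imports
#   import os                                     # F401: unused, E402: late import
#
# After: imports grouped at the top, unused ones removed, logger defined below.
import logging

mainlogger = logging.getLogger("mainlogger")
mainlogger.info("imports are contiguous and unused ones are removed")
```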
1 change: 0 additions & 1 deletion videotuna/base/utils_diffusion.py
@@ -2,7 +2,6 @@

import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat

