diff --git a/.gitignore b/.gitignore index ae6b9fb4..99a83d85 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,6 @@ SwissArmyTransformer/ trainingdata temp - +.cursor *.outputs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2e18a6d..0b2dba6b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,12 +10,12 @@ repos: pass_filenames: false language: system stages: [pre-commit] -# - id: linting -# name: linting -# entry: poetry run lint -# pass_filenames: false -# language: system -# stages: [commit] + - id: linting + name: linting + entry: poetry run lint + pass_filenames: false + language: system + stages: [commit] # - id: type-checking # name: type checking # entry: poetry run type-check diff --git a/pyproject.toml b/pyproject.toml index f6e1910d..71cf76c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,6 @@ module = [ ] ignore_missing_imports = true -[tool.ruff] +[tool.ruff.lint] select = ["E", "F", "C90"] -ignore = [] +ignore = ["E501", "C901"] diff --git a/scripts/inference_flux_lora.py b/scripts/inference_flux_lora.py index b864750d..bddf5553 100644 --- a/scripts/inference_flux_lora.py +++ b/scripts/inference_flux_lora.py @@ -3,6 +3,7 @@ import torch from diffusers import FluxPipeline + from videotuna.utils.inference_utils import load_prompts_from_txt @@ -56,9 +57,16 @@ def inference(args): parser.add_argument( "--model_type", type=str, default="dev", choices=["dev", "schnell"] ) - parser.add_argument("--prompt", type=str, default="A photo of a cat", help="Inference prompt, string or path to a .txt file") + parser.add_argument( + "--prompt", + type=str, + default="A photo of a cat", + help="Inference prompt, string or path to a .txt file", + ) parser.add_argument("--out_path", type=str, default="./results/t2i/image.png") - parser.add_argument("--lora_path", type=str, default=None, help="Full path to lora weights") + parser.add_argument( + "--lora_path", type=str, default=None, help="Full path to lora weights" + ) parser.add_argument("--width", type=int, default=1360) parser.add_argument("--height", type=int, default=768) parser.add_argument("--num_inference_steps", type=int, default=4) diff --git a/scripts/train_flux_lora.py b/scripts/train_flux_lora.py index 3ded7d19..606a5463 100644 --- a/scripts/train_flux_lora.py +++ b/scripts/train_flux_lora.py @@ -23,6 +23,7 @@ logger = logging.getLogger("SimpleTuner") logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) + def add_timestamp_to_output_dir(output_dir): time_str = time.strftime("%Y%m%d%H%M%S") folder_name = output_dir.stem @@ -33,6 +34,7 @@ def add_timestamp_to_output_dir(output_dir): output_dir = output_dir.parent / folder_name return str(output_dir) + def config_process(config): # add timestamp to the output_dir output_dir = Path(config["--output_dir"]) @@ -42,6 +44,7 @@ def config_process(config): json.dump(config, f, indent=4) return config + def load_yaml_config(config_path): with open(config_path) as f: config = yaml.safe_load(f) @@ -58,6 +61,7 @@ def load_yaml_config(config_path): return config, data_config_json + def load_json_config(config_path, data_config_path): # load config files with open(config_path) as f: @@ -68,6 +72,7 @@ def load_json_config(config_path, data_config_path): config = config_process(config) return config, data_config + def main(args): try: import multiprocessing diff --git a/tests/datasets/test_dataset_from_csv.py b/tests/datasets/test_dataset_from_csv.py index 3d0fa1c1..5237a772 100644 --- a/tests/datasets/test_dataset_from_csv.py 
+++ b/tests/datasets/test_dataset_from_csv.py @@ -240,7 +240,8 @@ def test_video_dataset_from_csv_with_split(self): print(f"len(dataset): {len(val_dataset)}") self.assertLessEqual(len(val_dataset), 128) self.assertEqual(val_dataset[0]["video"].shape[2], 256) - # Check if the sum of the lengths of the training and validation datasets is equal to the total number of samples + # Check if the sum of the lengths of the training and validation datasets + # is equal to the total number of samples self.assertEqual(len(train_dataset) + len(val_dataset), 128) diff --git a/videotuna/base/ddim.py b/videotuna/base/ddim.py index 40d54452..a99316a4 100644 --- a/videotuna/base/ddim.py +++ b/videotuna/base/ddim.py @@ -19,7 +19,7 @@ def __init__(self, model, schedule="linear", **kwargs): self.counter = 0 def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: + if isinstance(attr, torch.Tensor): if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) @@ -37,7 +37,9 @@ def make_schedule( assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + def to_torch(x): + return x.clone().detach().to(torch.float32).to(self.model.device) self.register_buffer("betas", to_torch(self.model.diffusion_scheduler.betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) @@ -133,22 +135,24 @@ def sample( if isinstance(conditioning, dict): try: cbs = conditioning[list(conditioning.keys())[0]].shape[0] - except: + except Exception: try: cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] - except: + except Exception: cbs = int( conditioning[list(conditioning.keys())[0]][0]["y"].shape[0] ) if cbs != batch_size: print( - f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + f"Warning: Got {cbs} conditionings but " + f"batch-size is {batch_size}" ) else: if conditioning.shape[0] != batch_size: print( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + f"Warning: Got {conditioning.shape[0]} conditionings but " + f"batch-size is {batch_size}" ) self.make_schedule( diff --git a/videotuna/base/ddim_multiplecond.py b/videotuna/base/ddim_multiplecond.py index db49e6de..bad21c63 100644 --- a/videotuna/base/ddim_multiplecond.py +++ b/videotuna/base/ddim_multiplecond.py @@ -1,5 +1,3 @@ -import copy - import numpy as np import torch from tqdm import tqdm @@ -21,7 +19,7 @@ def __init__(self, model, schedule="linear", **kwargs): self.counter = 0 def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: + if isinstance(attr, torch.Tensor): if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) @@ -39,7 +37,9 @@ def make_schedule( assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + def to_torch(x): + return x.clone().detach().to(torch.float32).to(self.model.device) if self.model.use_scale: self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] @@ -120,7 +120,8 @@ def sample( fs=None, timestep_spacing="uniform", # uniform_trailing for starting from last timestep guidance_rescale=0.0, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + # this has to come in the same format as the conditioning, + # e.g. as encoded tokens, ... 
**kwargs, ): @@ -129,7 +130,7 @@ def sample( if isinstance(conditioning, dict): try: cbs = conditioning[list(conditioning.keys())[0]].shape[0] - except: + except Exception: cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] if cbs != batch_size: diff --git a/videotuna/base/ddpm3d.py b/videotuna/base/ddpm3d.py index c8dc7581..98164700 100644 --- a/videotuna/base/ddpm3d.py +++ b/videotuna/base/ddpm3d.py @@ -7,46 +7,37 @@ """ import logging -import os import random from contextlib import contextmanager from functools import partial import numpy as np -from einops import rearrange, repeat -from tqdm import tqdm - -mainlogger = logging.getLogger("mainlogger") - import peft import pytorch_lightning as pl import torch -import torch.nn as nn +from einops import rearrange, repeat from pytorch_lightning.utilities import rank_zero_only from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR from torchvision.utils import make_grid +from tqdm import tqdm from videotuna.base.ddim import DDIMSampler -from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl +from videotuna.base.distributions import DiagonalGaussianDistribution from videotuna.base.ema import LitEma -from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr # import rlhf utils from videotuna.lvdm.models.rlhf_utils.batch_ddim import batch_ddim_sampling from videotuna.lvdm.models.rlhf_utils.reward_fn import aesthetic_loss_fn from videotuna.lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler -from videotuna.lvdm.modules.utils import ( - default, - disabled_train, - exists, - extract_into_tensor, - noise_like, -) +from videotuna.lvdm.modules.utils import default, disabled_train, extract_into_tensor from videotuna.utils.common_utils import instantiate_from_config __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} +mainlogger = logging.getLogger("mainlogger") + + class DDPMFlow(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space def __init__( @@ -430,7 +421,7 @@ def load_lora_from_ckpt(self, model, path): f"Parameter {key} from lora_state_dict was not copied to the model." ) # print(f"Parameter {key} from lora_state_dict was not copied to the model.") - print(f"All Parameters was copied successfully.") + print("All Parameters was copied successfully.") def inject_lora(self): """inject lora into the denoising module. @@ -519,7 +510,7 @@ def __init__( try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -1586,7 +1577,7 @@ def configure_optimizers(self): if self.cond_stage_trainable: params_cond_stage = [ - p for p in self.cond_stage_model.parameters() if p.requires_grad == True + p for p in self.cond_stage_model.parameters() if p.requires_grad is True ] mainlogger.info( f"@Training [{len(params_cond_stage)}] Paramters for Cond_stage_model." 
diff --git a/videotuna/base/diffusion_schedulers.py b/videotuna/base/diffusion_schedulers.py index 8aa87cb4..391bf7f0 100644 --- a/videotuna/base/diffusion_schedulers.py +++ b/videotuna/base/diffusion_schedulers.py @@ -8,7 +8,6 @@ from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr from videotuna.lvdm.modules.utils import ( default, - disabled_train, exists, extract_into_tensor, noise_like, diff --git a/videotuna/base/ema.py b/videotuna/base/ema.py index 0e1447b0..1fbaf62a 100644 --- a/videotuna/base/ema.py +++ b/videotuna/base/ema.py @@ -49,7 +49,7 @@ def forward(self, model): one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -58,7 +58,7 @@ def copy_to(self, model): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/videotuna/base/iddpm3d.py b/videotuna/base/iddpm3d.py index 6d651a2e..3e9ae725 100644 --- a/videotuna/base/iddpm3d.py +++ b/videotuna/base/iddpm3d.py @@ -1,33 +1,23 @@ import enum import logging import math -import os import random -from contextlib import contextmanager from functools import partial import numpy as np -from einops import rearrange, repeat -from omegaconf.listconfig import ListConfig -from tqdm import tqdm - -mainlogger = logging.getLogger("mainlogger") - import torch -import torch.nn as nn +from einops import rearrange +from omegaconf.listconfig import ListConfig from pytorch_lightning.utilities import rank_zero_only from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR from torchvision.utils import make_grid +from tqdm import tqdm from videotuna.base.ddim import DDIMSampler from videotuna.base.ddpm3d import DDPMFlow from videotuna.base.diffusion_schedulers import DDPMScheduler from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl -from videotuna.base.utils_diffusion import ( - discretized_gaussian_log_likelihood, - make_beta_schedule, - rescale_zero_terminal_snr, -) +from videotuna.base.utils_diffusion import discretized_gaussian_log_likelihood from videotuna.lvdm.modules.utils import ( default, disabled_train, @@ -37,6 +27,8 @@ ) from videotuna.utils.common_utils import instantiate_from_config +mainlogger = logging.getLogger("mainlogger") + def mean_flat(tensor: torch.Tensor, mask=None) -> torch.Tensor: """ @@ -1039,7 +1031,7 @@ def __init__( try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -1305,7 +1297,7 @@ def apply_model(self, x_noisy, t, cond, **kwargs): key: [cond["c_crossattn"][0]["y"]], "mask": [cond["c_crossattn"][0]["mask"]], } - except: + except Exception: cond = {key: [cond["y"]], "mask": [cond["mask"]]} # support mask for T5 else: if isinstance(cond, dict): diff --git a/videotuna/base/utils_diffusion.py b/videotuna/base/utils_diffusion.py index 1b3a4675..bf2fe42e 100644 --- a/videotuna/base/utils_diffusion.py +++ b/videotuna/base/utils_diffusion.py @@ -2,7 +2,6 @@ import numpy as np import torch -import torch.nn.functional as F from einops import repeat diff --git a/videotuna/cogvideo_hf/cogvideo_pl.py b/videotuna/cogvideo_hf/cogvideo_pl.py index 83334cc4..a03de097 100644 --- 
a/videotuna/cogvideo_hf/cogvideo_pl.py +++ b/videotuna/cogvideo_hf/cogvideo_pl.py @@ -5,21 +5,12 @@ # from videotuna.base.ddpm3d import DDPM import pytorch_lightning as pl import torch -from diffusers import ( - AutoencoderKLCogVideoX, - CogVideoXDPMScheduler, - CogVideoXTransformer3DModel, -) +from diffusers import CogVideoXDPMScheduler from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from peft import ( - LoraConfig, - get_peft_model, - get_peft_model_state_dict, - set_peft_model_state_dict, -) +from peft import get_peft_model from transformers import T5EncoderModel, T5Tokenizer from videotuna.utils.common_utils import instantiate_from_config diff --git a/videotuna/cogvideo_sat/arguments.py b/videotuna/cogvideo_sat/arguments.py index 25562b97..aa617394 100644 --- a/videotuna/cogvideo_sat/arguments.py +++ b/videotuna/cogvideo_sat/arguments.py @@ -1,5 +1,4 @@ import argparse -import json import os import sys import warnings diff --git a/videotuna/cogvideo_sat/data_video.py b/videotuna/cogvideo_sat/data_video.py index 00ea3b48..21cdfcb0 100644 --- a/videotuna/cogvideo_sat/data_video.py +++ b/videotuna/cogvideo_sat/data_video.py @@ -3,6 +3,7 @@ import os import random import sys +import threading from fractions import Fraction from functools import partial from typing import Any, Dict, Optional, Tuple, Union @@ -22,7 +23,7 @@ av, ) from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import center_crop, resize +from torchvision.transforms.functional import resize def read_video( @@ -207,9 +208,6 @@ def load_video( return pad_last_frame(tensor_frms, num_frames) -import threading - - def load_video_with_timeout(*args, **kwargs): video_container = {} diff --git a/videotuna/cogvideo_sat/diffusion_video.py b/videotuna/cogvideo_sat/diffusion_video.py index e19e03ba..ea4a5575 100644 --- a/videotuna/cogvideo_sat/diffusion_video.py +++ b/videotuna/cogvideo_sat/diffusion_video.py @@ -9,7 +9,6 @@ from sat import mpu from sat.helpers import print_rank0 from sgm.modules import UNCONDITIONAL_CONFIG -from sgm.modules.autoencoding.temporal_ae import VideoDecoder from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER from sgm.util import ( default, @@ -109,7 +108,7 @@ def __init__(self, args, **kwargs): def disable_untrainable_params(self): total_trainable = 0 for n, p in self.named_parameters(): - if p.requires_grad == False: + if p.requires_grad is False: continue flag = False for prefix in self.not_trainable_prefixes: @@ -285,14 +284,15 @@ def sample( scale = None scale_emb = None - denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser( - self.model, - input, - sigma, - c, - concat_images=concat_images, - **addtional_model_inputs, - ) + def denoiser(input, sigma, c, **addtional_model_inputs): + return self.denoiser( + self.model, + input, + sigma, + c, + concat_images=concat_images, + **addtional_model_inputs, + ) samples = self.sampler( denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs diff --git a/videotuna/cogvideo_sat/dit_video_concat.py b/videotuna/cogvideo_sat/dit_video_concat.py index 0654dbde..a0b3510f 100644 --- a/videotuna/cogvideo_sat/dit_video_concat.py +++ b/videotuna/cogvideo_sat/dit_video_concat.py @@ -762,7 +762,7 @@ def __init__( self.inner_hidden_size = hidden_size * 4 try: self.dtype = 
str_to_dtype[kwargs.pop("dtype")] - except: + except Exception: self.dtype = torch.float32 if use_SwiGLU: diff --git a/videotuna/cogvideo_sat/sgm/__init__.py b/videotuna/cogvideo_sat/sgm/__init__.py index 1c448236..3dc1f76b 100644 --- a/videotuna/cogvideo_sat/sgm/__init__.py +++ b/videotuna/cogvideo_sat/sgm/__init__.py @@ -1,4 +1 @@ -from .models import AutoencodingEngine -from .util import get_configs_path, instantiate_from_config - __version__ = "0.1.0" diff --git a/videotuna/cogvideo_sat/sgm/models/__init__.py b/videotuna/cogvideo_sat/sgm/models/__init__.py index e72b8659..e69de29b 100644 --- a/videotuna/cogvideo_sat/sgm/models/__init__.py +++ b/videotuna/cogvideo_sat/sgm/models/__init__.py @@ -1 +0,0 @@ -from .autoencoder import AutoencodingEngine diff --git a/videotuna/cogvideo_sat/sgm/models/autoencoder.py b/videotuna/cogvideo_sat/sgm/models/autoencoder.py index b18d38ed..683b49a7 100644 --- a/videotuna/cogvideo_sat/sgm/models/autoencoder.py +++ b/videotuna/cogvideo_sat/sgm/models/autoencoder.py @@ -1,27 +1,22 @@ import logging import math -import random import re from abc import abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import pytorch_lightning as pl import torch import torch.distributed -import torch.nn as nn -from einops import rearrange from packaging import version from ..modules.autoencoding.regularizers import AbstractRegularizer -from ..modules.cp_enc_dec import _conv_gather, _conv_split +from ..modules.cp_enc_dec import _conv_split from ..modules.ema import LitEma from ..util import ( default, get_context_parallel_group, get_context_parallel_group_rank, - get_nested_attribute, get_obj_from_str, initialize_context_parallel, instantiate_from_config, diff --git a/videotuna/cogvideo_sat/sgm/modules/__init__.py b/videotuna/cogvideo_sat/sgm/modules/__init__.py index 0db1d771..ec4742d8 100644 --- a/videotuna/cogvideo_sat/sgm/modules/__init__.py +++ b/videotuna/cogvideo_sat/sgm/modules/__init__.py @@ -1,5 +1,3 @@ -from .encoders.modules import GeneralConditioner - UNCONDITIONAL_CONFIG = { "target": "sgm.modules.GeneralConditioner", "params": {"emb_models": []}, diff --git a/videotuna/cogvideo_sat/sgm/modules/attention.py b/videotuna/cogvideo_sat/sgm/modules/attention.py index b122b111..6963adb0 100644 --- a/videotuna/cogvideo_sat/sgm/modules/attention.py +++ b/videotuna/cogvideo_sat/sgm/modules/attention.py @@ -46,7 +46,7 @@ import xformers.ops XFORMERS_IS_AVAILABLE = True -except: +except Exception: XFORMERS_IS_AVAILABLE = False print("no module 'xformers'. 
Processing without...") diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/__init__.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/__init__.py index b3bb81d9..6b316c7a 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/__init__.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/__init__.py @@ -5,4 +5,3 @@ from .discriminator_loss import GeneralLPIPSWithDiscriminator from .lpips import LatentLPIPS -from .video_loss import VideoAutoencoderLoss diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/discriminator_loss.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/discriminator_loss.py index cefcb1d9..23e807e7 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/discriminator_loss.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/discriminator_loss.py @@ -275,7 +275,6 @@ def forward( f"{split}/loss/nll": nll_loss.detach().mean(), f"{split}/loss/rec": rec_loss.detach().mean(), f"{split}/loss/percep": p_loss.detach().mean(), - f"{split}/loss/rec": rec_loss.detach().mean(), f"{split}/loss/g": g_loss.detach().mean(), f"{split}/scalars/logvar": self.logvar.detach(), f"{split}/scalars/d_weight": d_weight.detach(), diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/video_loss.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/video_loss.py index 93497302..7821c18c 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/video_loss.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/losses/video_loss.py @@ -1,12 +1,11 @@ from math import log2 -from typing import Any, Union +from typing import Any import torch import torch.nn as nn import torch.nn.functional as F -import torchvision from beartype import beartype -from einops import einsum, rearrange, repeat +from einops import einsum, rearrange from einops.layers.torch import Rearrange from kornia.filters import filter3d from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/loss/lpips.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/loss/lpips.py index 3e34f3d0..7b952b8b 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/loss/lpips.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -59,8 +59,8 @@ def forward(self, input, target): for kk in range(len(self.chns)) ] val = res[0] - for l in range(1, len(self.chns)): - val += res[l] + for index in range(1, len(self.chns)): + val += res[index] return val diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/model/model.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/model/model.py index 5d767fcf..94130ca4 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/model/model.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/lpips/model/model.py @@ -10,7 +10,7 @@ def weights_init(m): if classname.find("Conv") != -1: try: nn.init.normal_(m.weight.data, 0.0, 0.02) - except: + except Exception: nn.init.normal_(m.conv.weight.data, 0.0, 0.02) elif classname.find("BatchNorm") != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) @@ -35,9 +35,8 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): norm_layer = nn.BatchNorm2d else: norm_layer = ActNorm - if ( - type(norm_layer) == functools.partial - ): # no need to use bias as BatchNorm2d has affine parameters + if isinstance(n_layers, functools.partial): + # no need to use bias as BatchNorm2d has affine 
parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/magvit2_pytorch.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/magvit2_pytorch.py index 16370ad9..5659639f 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/magvit2_pytorch.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/magvit2_pytorch.py @@ -2,7 +2,7 @@ import pickle from collections import namedtuple from functools import partial, wraps -from math import ceil, log2, sqrt +from math import log2 from pathlib import Path import torch diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/regularizers/__init__.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/regularizers/__init__.py index 6065fb20..b9bf0cca 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/regularizers/__init__.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/regularizers/__init__.py @@ -1,9 +1,6 @@ -from abc import abstractmethod from typing import Any, Tuple import torch -import torch.nn as nn -import torch.nn.functional as F from ....modules.distributions.distributions import DiagonalGaussianDistribution from .base import AbstractRegularizer diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py index 9c3dc8b2..9e35994e 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py @@ -6,7 +6,7 @@ import torch.nn as nn from einops import rearrange -from .movq_enc_3d import CausalConv3d, DownSample3D, Upsample3D +from .movq_enc_3d import CausalConv3d, Upsample3D def cast_tuple(t, length=1): diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py index 008eab2b..db879560 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py @@ -4,12 +4,9 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -from beartype import beartype -from beartype.typing import List, Optional, Tuple, Union from einops import rearrange -from .movq_enc_3d import CausalConv3d, DownSample3D, Upsample3D +from .movq_enc_3d import CausalConv3d, Upsample3D def cast_tuple(t, length=1): diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py index 2596d328..2d8ef5c8 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from beartype import beartype -from beartype.typing import List, Optional, Tuple, Union +from beartype.typing import Tuple, Union from einops import rearrange diff --git a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/quantize.py b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/quantize.py index 96cc56ac..81a64a29 100644 --- a/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/quantize.py +++ b/videotuna/cogvideo_sat/sgm/modules/autoencoding/vqvae/quantize.py @@ -79,8 +79,8 @@ def unmap_to_all(self, inds): def forward(self, z, temp=None, rescale_logits=False, return_logits=False): assert temp is None or temp == 
1.0, "Only for interface compatible with Gumbel" - assert rescale_logits == False, "Only for interface compatible with Gumbel" - assert return_logits == False, "Only for interface compatible with Gumbel" + assert not rescale_logits, "Only for interface compatible with Gumbel" + assert not return_logits, "Only for interface compatible with Gumbel" # reshape z -> (batch, height, width, channel) and flatten z = rearrange(z, "b c h w -> b h w c").contiguous() z_flattened = z.view(-1, self.e_dim) diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/__init__.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/__init__.py index fccebf95..e69de29b 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/__init__.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/__init__.py @@ -1,6 +0,0 @@ -from .denoiser import Denoiser -from .discretizer import Discretization -from .model import Decoder, Encoder, Model -from .openaimodel import UNetModel -from .sampling import BaseDiffusionSampler -from .wrappers import OpenAIWrapper diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser.py index ffece73c..c764fe4b 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict import torch import torch.nn as nn diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser_scaling.py index 05362a00..bf0dbfcf 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser_scaling.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Tuple +from typing import Tuple import torch diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/guiders.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/guiders.py index 4175e133..8202d399 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/guiders.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/guiders.py @@ -1,13 +1,11 @@ -import logging import math from abc import ABC, abstractmethod from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Tuple import torch -from einops import rearrange, repeat -from ...util import append_dims, default, instantiate_from_config +from ...util import default, instantiate_from_config class Guider(ABC): @@ -28,7 +26,10 @@ class VanillaCFG: def __init__(self, scale, dyn_thresh_config=None): self.scale = scale - scale_schedule = lambda scale, sigma: scale # independent of step + + def scale_schedule(scale, _sigma): + return scale + self.scale_schedule = partial(scale_schedule, scale) self.dyn_thresh = instantiate_from_config( default( @@ -60,10 +61,13 @@ def prepare_inputs(self, x, s, c, uc): class DynamicCFG(VanillaCFG): def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): super().__init__(scale, dyn_thresh_config) - scale_schedule = ( - lambda scale, sigma, step_index: 1 - + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 - ) + + def scale_schedule(scale, _sigma, step_index): + return ( + 1 + + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 + ) + self.scale_schedule = partial(scale_schedule, scale) self.dyn_thresh = 
instantiate_from_config( default( diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/lora.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/lora.py index d3871eb4..3a93f1dc 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/lora.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/lora.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import List, Optional, Set, Type import torch import torch.nn.functional as F @@ -324,7 +324,7 @@ def _find_modules_v2( while path: try: parent = parent.get_submodule(path.pop(0)) - except: + except Exception: flag = True break if flag: diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/model.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/model.py index 26efd078..3950d3ad 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/model.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/model.py @@ -13,7 +13,7 @@ import xformers.ops XFORMERS_IS_AVAILABLE = True -except: +except Exception: XFORMERS_IS_AVAILABLE = False print("no module 'xformers'. Processing without...") @@ -295,7 +295,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): elif attn_type == "vanilla-xformers": print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": + elif attn_type == "memory-efficient-cross-attn": attn_kwargs["query_dim"] = in_channels return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) elif attn_type == "none": diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/openaimodel.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/openaimodel.py index 167b78e2..0b38d20c 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/openaimodel.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/openaimodel.py @@ -2,7 +2,7 @@ import os from abc import abstractmethod from functools import partial -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional import numpy as np import torch as th @@ -600,7 +600,7 @@ def __init__( assert ( use_spatial_transformer ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
- if type(context_dim) == ListConfig: + if isinstance(context_dim, ListConfig): context_dim = list(context_dim) if num_heads_upsample == -1: diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling.py index 7067334d..5fd2958a 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling.py @@ -584,7 +584,7 @@ def denoise( if ofs is not None: additional_model_inputs["ofs"] = ofs - if isinstance(scale, torch.Tensor) == False and scale == 1: + if not isinstance(scale, torch.Tensor) and scale == 1: additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep if scale_emb is not None: additional_model_inputs["scale_emb"] = scale_emb diff --git a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling_utils.py b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling_utils.py index 4c26a75e..8725e70f 100644 --- a/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling_utils.py +++ b/videotuna/cogvideo_sat/sgm/modules/diffusionmodules/sampling_utils.py @@ -23,7 +23,9 @@ def __call__(self, uncond, cond, scale): def dynamic_threshold(x, p=0.95): N, T, C, H, W = x.shape x = rearrange(x, "n t c h w -> n c (t h w)") - l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True) + l, r = x.quantile( # noqa: E741 + q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True + ) s = torch.maximum(-l, r) threshold_mask = (s > 1).expand(-1, -1, H * W * T) if threshold_mask.any(): diff --git a/videotuna/cogvideo_sat/sgm/modules/ema.py b/videotuna/cogvideo_sat/sgm/modules/ema.py index 96f64345..2aa2a47e 100644 --- a/videotuna/cogvideo_sat/sgm/modules/ema.py +++ b/videotuna/cogvideo_sat/sgm/modules/ema.py @@ -53,7 +53,7 @@ def forward(self, model): one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -62,7 +62,7 @@ def copy_to(self, model): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/videotuna/cogvideo_sat/sgm/modules/encoders/modules.py b/videotuna/cogvideo_sat/sgm/modules/encoders/modules.py index bf90110d..a25495f5 100644 --- a/videotuna/cogvideo_sat/sgm/modules/encoders/modules.py +++ b/videotuna/cogvideo_sat/sgm/modules/encoders/modules.py @@ -1,22 +1,14 @@ -import math from contextlib import nullcontext -from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union -import kornia import numpy as np import torch import torch.nn as nn -from einops import rearrange, repeat from omegaconf import ListConfig -from torch.utils.checkpoint import checkpoint from transformers import T5EncoderModel, T5Tokenizer from ...util import ( - append_dims, - autocast, count_params, - default, disabled_train, expand_dims_like, instantiate_from_config, @@ -263,7 +255,7 @@ def __init__( cache_dir=None, ): super().__init__() - if model_dir is not "google/t5-v1_1-xxl": + if model_dir != "google/t5-v1_1-xxl": self.tokenizer = T5Tokenizer.from_pretrained(model_dir) self.transformer = T5EncoderModel.from_pretrained(model_dir) else: diff --git a/videotuna/cogvideo_sat/sgm/modules/video_attention.py 
b/videotuna/cogvideo_sat/sgm/modules/video_attention.py index 756ae4bf..21082d65 100644 --- a/videotuna/cogvideo_sat/sgm/modules/video_attention.py +++ b/videotuna/cogvideo_sat/sgm/modules/video_attention.py @@ -1,6 +1,16 @@ import torch -from ..modules.attention import * +from ..modules.attention import ( + CrossAttention, + FeedForward, + MemoryEfficientCrossAttention, + SpatialTransformer, + checkpoint, + exists, + nn, + rearrange, + repeat, +) from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding @@ -235,10 +245,10 @@ def __init__( def forward( self, x: torch.Tensor, - context: Optional[torch.Tensor] = None, - time_context: Optional[torch.Tensor] = None, - timesteps: Optional[int] = None, - image_only_indicator: Optional[torch.Tensor] = None, + context: torch.Tensor | None = None, + time_context: torch.Tensor | None = None, + timesteps: int | None = None, + image_only_indicator: torch.Tensor | None = None, ) -> torch.Tensor: _, _, h, w = x.shape x_in = x diff --git a/videotuna/cogvideo_sat/sgm/util.py b/videotuna/cogvideo_sat/sgm/util.py index c85a493f..4a162587 100644 --- a/videotuna/cogvideo_sat/sgm/util.py +++ b/videotuna/cogvideo_sat/sgm/util.py @@ -3,6 +3,7 @@ import os from functools import partial from inspect import isfunction +from math import sqrt import fsspec import numpy as np @@ -128,11 +129,11 @@ def get_string_from_tuple(s): # Convert the string to a tuple t = eval(s) # Check if the type of t is tuple - if type(t) == tuple: + if isinstance(t, tuple): return t[0] else: pass - except: + except Exception: pass return s @@ -270,7 +271,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config, **extra_kwargs): - if not "target" in config: + if "target" not in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": @@ -381,9 +382,6 @@ def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): return (current_attribute, current_key) if return_key else current_attribute -from math import sqrt - - class SeededNoise: def __init__(self, seeds, weights): self.seeds = seeds diff --git a/videotuna/cogvideo_sat/sgm/webds.py b/videotuna/cogvideo_sat/sgm/webds.py index 078ed6dd..5fdbfc8d 100644 --- a/videotuna/cogvideo_sat/sgm/webds.py +++ b/videotuna/cogvideo_sat/sgm/webds.py @@ -9,7 +9,7 @@ import webdataset as wds from webdataset import DataPipeline, ResampledShards, tarfile_to_samples from webdataset.filters import pipelinefilter -from webdataset.gopen import gopen, gopen_schemes +from webdataset.gopen import Pipe, gopen, gopen_schemes from webdataset.handlers import reraise_exception from webdataset.tariterators import group_by_keys, url_opener @@ -61,7 +61,7 @@ def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True): group = get_data_parallel_group() print_rank0("Using megatron data parallel group.") - except: + except Exception: from sat.mpu import get_data_parallel_group try: @@ -126,7 +126,7 @@ def tar_file_iterator_with_meta( meta_list = [] try: meta_list.append(json.loads(line)) - except Exception as exn: + except Exception: from sat.helpers import print_rank0 print_rank0( @@ -135,7 +135,7 @@ def tar_file_iterator_with_meta( ) continue for item in meta_list: - if not item["key"] in meta_data: + if item["key"] not in meta_data: meta_data[item["key"]] = {} for meta_name in meta_names: if meta_name in item: @@ -332,10 +332,6 @@ def __init__( ) -# rclone support -from webdataset.gopen import Pipe - - def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 
32): """Open a URL with `curl`. diff --git a/videotuna/cogvideo_sat/vae_modules/attention.py b/videotuna/cogvideo_sat/vae_modules/attention.py index caa594ef..6f80fb23 100644 --- a/videotuna/cogvideo_sat/vae_modules/attention.py +++ b/videotuna/cogvideo_sat/vae_modules/attention.py @@ -46,7 +46,7 @@ import xformers.ops XFORMERS_IS_AVAILABLE = True -except: +except Exception: XFORMERS_IS_AVAILABLE = False print("no module 'xformers'. Processing without...") diff --git a/videotuna/cogvideo_sat/vae_modules/cp_enc_dec.py b/videotuna/cogvideo_sat/vae_modules/cp_enc_dec.py index 6db3e499..ff304339 100644 --- a/videotuna/cogvideo_sat/vae_modules/cp_enc_dec.py +++ b/videotuna/cogvideo_sat/vae_modules/cp_enc_dec.py @@ -5,8 +5,7 @@ import torch.distributed import torch.nn as nn import torch.nn.functional as F -from beartype import beartype -from beartype.typing import List, Optional, Tuple, Union +from beartype.typing import Tuple, Union from einops import rearrange from sgm.util import ( get_context_parallel_group, diff --git a/videotuna/cogvideo_sat/vae_modules/ema.py b/videotuna/cogvideo_sat/vae_modules/ema.py index 96f64345..2aa2a47e 100644 --- a/videotuna/cogvideo_sat/vae_modules/ema.py +++ b/videotuna/cogvideo_sat/vae_modules/ema.py @@ -53,7 +53,7 @@ def forward(self, model): one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -62,7 +62,7 @@ def copy_to(self, model): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/videotuna/cogvideo_sat/vae_modules/utils.py b/videotuna/cogvideo_sat/vae_modules/utils.py index a52d94d6..150d6020 100644 --- a/videotuna/cogvideo_sat/vae_modules/utils.py +++ b/videotuna/cogvideo_sat/vae_modules/utils.py @@ -120,11 +120,11 @@ def get_string_from_tuple(s): # Convert the string to a tuple t = eval(s) # Check if the type of t is tuple - if type(t) == tuple: + if isinstance(t, tuple): return t[0] else: pass - except: + except Exception: pass return s @@ -262,7 +262,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": diff --git a/videotuna/data/cogvideo_dataset.py b/videotuna/data/cogvideo_dataset.py index b1614863..20b72d3d 100644 --- a/videotuna/data/cogvideo_dataset.py +++ b/videotuna/data/cogvideo_dataset.py @@ -1,15 +1,13 @@ -import argparse import logging -import math -import os -import shutil from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Optional import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from torchvision import transforms +logger = logging.getLogger(__name__) + class VideoDataset(Dataset): def __init__( diff --git a/videotuna/data/datasets.py b/videotuna/data/datasets.py index ba16dfc4..12496d22 100644 --- a/videotuna/data/datasets.py +++ b/videotuna/data/datasets.py @@ -4,7 +4,7 @@ sys.path.append(os.getcwd()) import copy import random -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Union import pandas as pd import torch @@ -215,7 +215,7 @@ def __getitem__(self, index): data_item = self.getitem(index) 
self.safe_data_list.add(index) return data_item - except (ValueError, AssertionError) as e: + except (ValueError, AssertionError): import traceback traceback.print_exc() diff --git a/videotuna/data/datasets_utils.py b/videotuna/data/datasets_utils.py index 2725be6d..03a51f58 100644 --- a/videotuna/data/datasets_utils.py +++ b/videotuna/data/datasets_utils.py @@ -2,15 +2,12 @@ import decord import numpy as np import torch -import torchvision.transforms as transforms from decord import VideoReader, cpu from einops import rearrange from PIL import Image from torchvision.io import write_video from torchvision.utils import save_image -from . import transforms - IMG_EXTS = {"jpg", "bmp", "png", "jpeg", "rgb", "tif"} VIDEO_EXTS = {"mp4", "avi", "mov", "flv", "mkv", "webm", "wmv", "mov"} diff --git a/videotuna/data/lightning_data.py b/videotuna/data/lightning_data.py index 9348b0eb..01826bf4 100644 --- a/videotuna/data/lightning_data.py +++ b/videotuna/data/lightning_data.py @@ -1,5 +1,3 @@ -import argparse -import glob import os import sys from functools import partial @@ -9,12 +7,12 @@ import torch from torch.utils.data import DataLoader, Dataset -os.chdir(sys.path[0]) -sys.path.append("..") - from videotuna.data.base import Txt2ImgIterableBaseDataset from videotuna.utils.common_utils import instantiate_from_config +os.chdir(sys.path[0]) +sys.path.append("..") + def worker_init_fn(_): worker_info = torch.utils.data.get_worker_info() @@ -164,7 +162,7 @@ def _test_dataloader(self, shuffle=False): is_iterable_dataset = isinstance( self.datasets["train"], Txt2ImgIterableBaseDataset ) - except: + except Exception: is_iterable_dataset = isinstance( self.datasets["test"], Txt2ImgIterableBaseDataset ) diff --git a/videotuna/data/macvid.py b/videotuna/data/macvid.py index d20810f9..447be4a1 100644 --- a/videotuna/data/macvid.py +++ b/videotuna/data/macvid.py @@ -3,14 +3,11 @@ UNFINISHED UNFINISHED UNFINISHED """ -import glob import json import os import random import decord -import pandas as pd -import torch import yaml from decord import VideoReader, cpu from torch.utils.data import Dataset @@ -94,7 +91,7 @@ def __getitem__(self, index): continue else: break - except: + except Exception: index += 1 print(f"Load video failed! path = {video_path}") return self.__getitem__(index) diff --git a/videotuna/data/rlhf.py b/videotuna/data/rlhf.py index 8eb5d763..84033d8a 100644 --- a/videotuna/data/rlhf.py +++ b/videotuna/data/rlhf.py @@ -1,12 +1,4 @@ -import glob -import json -import os -import random - -import pandas as pd import torch -import yaml -from decord import VideoReader, cpu from torch.utils.data import Dataset diff --git a/videotuna/data/transforms.py b/videotuna/data/transforms.py index 4b8bd22b..9b11f191 100644 --- a/videotuna/data/transforms.py +++ b/videotuna/data/transforms.py @@ -27,7 +27,6 @@ from einops import rearrange from PIL import Image from torchvision.datasets.folder import pil_loader -from torchvision.io import write_video from .datasets_utils import IMG_EXTS, VIDEO_EXTS diff --git a/videotuna/data/webvid_lvdm.py b/videotuna/data/webvid_lvdm.py index 9ab380cc..38a95da9 100644 --- a/videotuna/data/webvid_lvdm.py +++ b/videotuna/data/webvid_lvdm.py @@ -231,7 +231,7 @@ def __getitem__(self, index): continue else: pass - except: + except Exception: index += 1 print(f"Load video failed! path = {video_path}") continue @@ -257,7 +257,7 @@ def __getitem__(self, index): try: frames = video_reader.get_batch(frame_indices) break - except: + except Exception: print(f"Get frames failed! 
path = {video_path}") index += 1 continue diff --git a/videotuna/hyvideo/config.py b/videotuna/hyvideo/config.py index 24ed4577..6e575346 100644 --- a/videotuna/hyvideo/config.py +++ b/videotuna/hyvideo/config.py @@ -1,7 +1,8 @@ import argparse import re -from .constants import * +import constants + from .modules.models import HUNYUAN_VIDEO_CONFIG @@ -40,7 +41,7 @@ def add_network_args(parser: argparse.ArgumentParser): "--precision", type=str, default="bf16", - choices=PRECISIONS, + choices=constants.PRECISIONS, help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.", ) @@ -61,14 +62,14 @@ def add_extra_models_args(parser: argparse.ArgumentParser): "--vae", type=str, default="884-16c-hy", - choices=list(VAE_PATH), + choices=list(constants.VAE_PATH), help="Name of the VAE model.", ) group.add_argument( "--vae-precision", type=str, default="fp16", - choices=PRECISIONS, + choices=constants.PRECISIONS, help="Precision mode for the VAE model.", ) group.add_argument( @@ -82,14 +83,14 @@ def add_extra_models_args(parser: argparse.ArgumentParser): "--text-encoder", type=str, default="llm", - choices=list(TEXT_ENCODER_PATH), + choices=list(constants.TEXT_ENCODER_PATH), help="Name of the text encoder model.", ) group.add_argument( "--text-encoder-precision", type=str, default="fp16", - choices=PRECISIONS, + choices=constants.PRECISIONS, help="Precision mode for the text encoder model.", ) group.add_argument( @@ -105,21 +106,21 @@ def add_extra_models_args(parser: argparse.ArgumentParser): "--tokenizer", type=str, default="llm", - choices=list(TOKENIZER_PATH), + choices=list(constants.TOKENIZER_PATH), help="Name of the tokenizer model.", ) group.add_argument( "--prompt-template", type=str, default="dit-llm-encode", - choices=PROMPT_TEMPLATE, + choices=constants.PROMPT_TEMPLATE, help="Image prompt template for the decoder-only text encoder model.", ) group.add_argument( "--prompt-template-video", type=str, default="dit-llm-encode-video", - choices=PROMPT_TEMPLATE, + choices=constants.PROMPT_TEMPLATE, help="Video prompt template for the decoder-only text encoder model.", ) group.add_argument( @@ -139,14 +140,14 @@ def add_extra_models_args(parser: argparse.ArgumentParser): "--text-encoder-2", type=str, default="clipL", - choices=list(TEXT_ENCODER_PATH), + choices=list(constants.TEXT_ENCODER_PATH), help="Name of the second text encoder model.", ) group.add_argument( "--text-encoder-precision-2", type=str, default="fp16", - choices=PRECISIONS, + choices=constants.PRECISIONS, help="Precision mode for the second text encoder model.", ) group.add_argument( @@ -159,7 +160,7 @@ def add_extra_models_args(parser: argparse.ArgumentParser): "--tokenizer-2", type=str, default="clipL", - choices=list(TOKENIZER_PATH), + choices=list(constants.TOKENIZER_PATH), help="Name of the second tokenizer model.", ) group.add_argument( diff --git a/videotuna/hyvideo/diffusion/__init__.py b/videotuna/hyvideo/diffusion/__init__.py index 2141aa3d..e69de29b 100644 --- a/videotuna/hyvideo/diffusion/__init__.py +++ b/videotuna/hyvideo/diffusion/__init__.py @@ -1,2 +0,0 @@ -from .pipelines import HunyuanVideoPipeline -from .schedulers import FlowMatchDiscreteScheduler diff --git a/videotuna/hyvideo/diffusion/pipelines/__init__.py b/videotuna/hyvideo/diffusion/pipelines/__init__.py index e44cb619..e69de29b 100644 --- a/videotuna/hyvideo/diffusion/pipelines/__init__.py +++ b/videotuna/hyvideo/diffusion/pipelines/__init__.py @@ -1 +0,0 @@ -from .pipeline_hunyuan_video import HunyuanVideoPipeline 
diff --git a/videotuna/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py b/videotuna/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py index 904c7da1..027dd386 100644 --- a/videotuna/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/videotuna/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py @@ -40,7 +40,6 @@ unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor -from packaging import version from ...constants import PRECISION_TO_TYPE from ...modules import HYVideoDiffusionTransformer diff --git a/videotuna/hyvideo/diffusion/schedulers/__init__.py b/videotuna/hyvideo/diffusion/schedulers/__init__.py index 14f2ba33..e69de29b 100644 --- a/videotuna/hyvideo/diffusion/schedulers/__init__.py +++ b/videotuna/hyvideo/diffusion/schedulers/__init__.py @@ -1 +0,0 @@ -from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler diff --git a/videotuna/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py b/videotuna/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py index fda6a076..697023c5 100644 --- a/videotuna/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py +++ b/videotuna/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py @@ -20,7 +20,6 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union -import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin diff --git a/videotuna/hyvideo/modules/attenion.py b/videotuna/hyvideo/modules/attenion.py index b8aeff10..70004815 100644 --- a/videotuna/hyvideo/modules/attenion.py +++ b/videotuna/hyvideo/modules/attenion.py @@ -1,8 +1,6 @@ -import importlib.metadata import math import torch -import torch.nn as nn import torch.nn.functional as F try: diff --git a/videotuna/hyvideo/modules/embed_layers.py b/videotuna/hyvideo/modules/embed_layers.py index 917112d4..ecf759ff 100644 --- a/videotuna/hyvideo/modules/embed_layers.py +++ b/videotuna/hyvideo/modules/embed_layers.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from einops import rearrange, repeat from ..utils.helpers import to_2tuple diff --git a/videotuna/hyvideo/modules/models.py b/videotuna/hyvideo/modules/models.py index 1e06e1a0..28b7d50f 100644 --- a/videotuna/hyvideo/modules/models.py +++ b/videotuna/hyvideo/modules/models.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models import ModelMixin from einops import rearrange diff --git a/videotuna/hyvideo/modules/token_refiner.py b/videotuna/hyvideo/modules/token_refiner.py index aa84e972..3adbca9b 100644 --- a/videotuna/hyvideo/modules/token_refiner.py +++ b/videotuna/hyvideo/modules/token_refiner.py @@ -8,7 +8,7 @@ from .attenion import attention from .embed_layers import TextProjection, TimestepEmbedder from .mlp_layers import MLP -from .modulate_layers import apply_gate, modulate +from .modulate_layers import apply_gate from .norm_layers import get_norm_layer diff --git a/videotuna/hyvideo/text_encoder/__init__.py b/videotuna/hyvideo/text_encoder/__init__.py index 2c6773a6..5ad1600c 100644 --- a/videotuna/hyvideo/text_encoder/__init__.py +++ b/videotuna/hyvideo/text_encoder/__init__.py @@ -1,4 +1,3 @@ -from copy import deepcopy from dataclasses import dataclass from typing import Optional, Tuple diff --git a/videotuna/hyvideo/utils/data_utils.py b/videotuna/hyvideo/utils/data_utils.py 
index be4995b6..1c049c63 100644 --- a/videotuna/hyvideo/utils/data_utils.py +++ b/videotuna/hyvideo/utils/data_utils.py @@ -1,7 +1,5 @@ import math -import numpy as np - def align_to(value, alignment): """align hight, width according to alignment diff --git a/videotuna/lvdm/models/rlhf_utils/actpred_scorer.py b/videotuna/lvdm/models/rlhf_utils/actpred_scorer.py index da5872cb..50660359 100644 --- a/videotuna/lvdm/models/rlhf_utils/actpred_scorer.py +++ b/videotuna/lvdm/models/rlhf_utils/actpred_scorer.py @@ -56,7 +56,7 @@ def mapping_func(x): try: target_class_idx = self.model.config.label2id[target_action] - except: + except Exception: target_class_idx = self.model.config.label2id[mapping_func(target_action)] return target_class_idx diff --git a/videotuna/lvdm/models/rlhf_utils/aesthetic_scorer.py b/videotuna/lvdm/models/rlhf_utils/aesthetic_scorer.py index ae02d848..3cfe14d4 100644 --- a/videotuna/lvdm/models/rlhf_utils/aesthetic_scorer.py +++ b/videotuna/lvdm/models/rlhf_utils/aesthetic_scorer.py @@ -7,8 +7,7 @@ import numpy as np import torch import torch.nn as nn -from PIL import Image -from transformers import CLIPModel, CLIPProcessor +from transformers import CLIPModel # ASSETS_PATH = files("lvdm.models.rlhf_utils.pretrained_reward_models") ASSETS_PATH = "videotuna/lvdm/models/rlhf_utils/pretrained_reward_models" diff --git a/videotuna/lvdm/models/rlhf_utils/batch_ddim.py b/videotuna/lvdm/models/rlhf_utils/batch_ddim.py index 8181ec5c..a54469be 100644 --- a/videotuna/lvdm/models/rlhf_utils/batch_ddim.py +++ b/videotuna/lvdm/models/rlhf_utils/batch_ddim.py @@ -11,12 +11,10 @@ import torch import torchvision from decord import VideoReader, cpu - -sys.path.append("videotuna") from lvdm.models.rlhf_utils.rl_ddim import DDIMSampler +from PIL import Image -# import ipdb -# st = ipdb.set_trace +sys.path.append("videotuna") def batch_ddim_sampling( @@ -95,10 +93,10 @@ def batch_ddim_sampling( try: decode_frame = int(decode_frame) # it's a int - except: + except Exception: pass # modified by haoyu , here we need to distinguish trainable and non-trainable decode. 
- if type(decode_frame) == int: + if isinstance(decode_frame, int): frame_index = ( random.randint(0, samples.shape[2] - 1) if decode_frame == -1 @@ -153,7 +151,7 @@ def load_checkpoint(model, ckpt, full_strict): for key in state_dict["module"].keys(): new_pl_sd[key[16:]] = state_dict["module"][key] model.load_state_dict(new_pl_sd, strict=full_strict) - except: + except Exception: if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] model.load_state_dict(state_dict, strict=full_strict) @@ -211,9 +209,6 @@ def load_video_batch( return torch.stack(batch_tensor, dim=0) -from PIL import Image - - def load_image_batch(filepath_list, image_size=(256, 256)): batch_tensor = [] for filepath in filepath_list: diff --git a/videotuna/lvdm/models/rlhf_utils/compression_scorer.py b/videotuna/lvdm/models/rlhf_utils/compression_scorer.py index 196197a0..0de92e25 100644 --- a/videotuna/lvdm/models/rlhf_utils/compression_scorer.py +++ b/videotuna/lvdm/models/rlhf_utils/compression_scorer.py @@ -8,7 +8,7 @@ from PIL import Image # import albumentations as A -from transformers import CLIPModel, CLIPProcessor +from transformers import CLIPModel # import ipdb # st = ipdb.set_trace diff --git a/videotuna/lvdm/models/rlhf_utils/reward_fn.py b/videotuna/lvdm/models/rlhf_utils/reward_fn.py index addd9ba7..3da3ef08 100644 --- a/videotuna/lvdm/models/rlhf_utils/reward_fn.py +++ b/videotuna/lvdm/models/rlhf_utils/reward_fn.py @@ -1,27 +1,20 @@ # adapted from VADER https://github.com/mihirp1998/VADER import argparse -import glob -import math import os -import random import sys -import yaml +import torch sys.path.insert( 1, os.path.join(sys.path[0], "..", "..") ) # setting path to get Core and assets import hpsv2 -import lvdm.models.rlhf_utils.prompts as prompts_file import torchvision from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer from lvdm.models.rlhf_utils.actpred_scorer import ActPredScorer from lvdm.models.rlhf_utils.aesthetic_scorer import AestheticScorerDiff -from lvdm.models.rlhf_utils.compression_scorer import ( - JpegCompressionScorer, - jpeg_compressibility, -) +from lvdm.models.rlhf_utils.compression_scorer import JpegCompressionScorer from lvdm.models.rlhf_utils.weather_scorer import WeatherScorer from transformers import ( AutoImageProcessor, @@ -30,7 +23,6 @@ AutoModelForZeroShotObjectDetection, AutoProcessor, ) -from transformers.utils import ContextManagers # import ipdb # st = ipdb.set_trace diff --git a/videotuna/lvdm/models/rlhf_utils/rl_ddim.py b/videotuna/lvdm/models/rlhf_utils/rl_ddim.py index efe13043..144ee1a6 100644 --- a/videotuna/lvdm/models/rlhf_utils/rl_ddim.py +++ b/videotuna/lvdm/models/rlhf_utils/rl_ddim.py @@ -26,7 +26,7 @@ def __init__(self, model, schedule="linear", **kwargs): self.training_mode = False # default def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: + if isinstance(attr, torch.Tensor): if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) @@ -44,7 +44,9 @@ def make_schedule( assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + def to_torch(x): + return x.clone().detach().to(torch.float32).to(self.model.device) self.register_buffer("betas", to_torch(self.model.diffusion_scheduler.betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) @@ -137,7 +139,7 @@ def sample( if 
isinstance(conditioning, dict): try: cbs = conditioning[list(conditioning.keys())[0]].shape[0] - except: + except Exception: cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] if cbs != batch_size: @@ -249,7 +251,7 @@ def ddim_sampling( init_x0 = False clean_cond = kwargs.pop("clean_cond", False) - if self.training_mode == True: + if self.training_mode is True: # print("Training mode", self.training_mode) if self.backprop_mode == "last": backprop_cutoff_idx = self.ddim_num_steps - 1 @@ -262,7 +264,7 @@ def ddim_sampling( index = total_steps - i - 1 ts = torch.full((b,), step, device=device, dtype=torch.long) - if self.training_mode == True: + if self.training_mode is True: if i >= backprop_cutoff_idx: for name, param in self.model.named_parameters(): if "lora" in name: diff --git a/videotuna/lvdm/models/rlhf_utils/weather_scorer.py b/videotuna/lvdm/models/rlhf_utils/weather_scorer.py index 8ea0a507..679731b8 100644 --- a/videotuna/lvdm/models/rlhf_utils/weather_scorer.py +++ b/videotuna/lvdm/models/rlhf_utils/weather_scorer.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torchvision -from transformers import CLIPModel, CLIPProcessor +from transformers import CLIPModel class SimpleCNN(nn.Module): # parameter = 6333513 diff --git a/videotuna/lvdm/modules/attention.py b/videotuna/lvdm/modules/attention.py index dfcd9540..f913f6ce 100644 --- a/videotuna/lvdm/modules/attention.py +++ b/videotuna/lvdm/modules/attention.py @@ -10,7 +10,7 @@ import xformers.ops XFORMERS_IS_AVAILBLE = True -except: +except Exception: XFORMERS_IS_AVAILBLE = False from videotuna.lvdm.modules.utils import checkpoint, default, exists, zero_module diff --git a/videotuna/lvdm/modules/encoders/ip_resampler.py b/videotuna/lvdm/modules/encoders/ip_resampler.py index 6718d77f..69954357 100644 --- a/videotuna/lvdm/modules/encoders/ip_resampler.py +++ b/videotuna/lvdm/modules/encoders/ip_resampler.py @@ -80,7 +80,7 @@ def forward(self, x, latents): x = self.norm1(x) latents = self.norm2(latents) - b, l, _ = latents.shape + b, l, _ = latents.shape # noqa: E741 q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) diff --git a/videotuna/lvdm/modules/losses/__init__.py b/videotuna/lvdm/modules/losses/__init__.py index d808f59a..e69de29b 100644 --- a/videotuna/lvdm/modules/losses/__init__.py +++ b/videotuna/lvdm/modules/losses/__init__.py @@ -1 +0,0 @@ -from lvdm.modules.losses.contperceptual import LPIPSWithDiscriminator diff --git a/videotuna/lvdm/modules/networks/openaimodel3d.py b/videotuna/lvdm/modules/networks/openaimodel3d.py index e7aba037..a45df8c5 100644 --- a/videotuna/lvdm/modules/networks/openaimodel3d.py +++ b/videotuna/lvdm/modules/networks/openaimodel3d.py @@ -654,7 +654,7 @@ def forward( emb = self.time_embed(t_emb) if self.fps_cond: - if type(fps) == int: + if isinstance(fps, int): fps = torch.full_like(timesteps, fps) fps_emb = timestep_embedding(fps, self.model_channels, repeat_only=False) emb += self.fps_embedding(fps_emb) diff --git a/videotuna/lvdm/modules/utils.py b/videotuna/lvdm/modules/utils.py index 5ba79d7d..3068a95a 100644 --- a/videotuna/lvdm/modules/utils.py +++ b/videotuna/lvdm/modules/utils.py @@ -13,7 +13,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch import nn from videotuna.utils.common_utils import instantiate_from_config @@ -46,10 +45,14 @@ def extract_into_tensor(a, t, x_shape): def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( - shape[0], *((1,) * 
(len(shape) - 1)) - ) - noise = lambda: torch.randn(shape, device=device) + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + + def noise(): + return torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/videotuna/lvdm/modules/vae/autoencoder.py b/videotuna/lvdm/modules/vae/autoencoder.py index 2d433588..29c94103 100644 --- a/videotuna/lvdm/modules/vae/autoencoder.py +++ b/videotuna/lvdm/modules/vae/autoencoder.py @@ -1,6 +1,5 @@ import os -import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F @@ -41,7 +40,7 @@ def __init__( self.test_args = test_args self.logdir = logdir if colorize_nlabels is not None: - assert type(colorize_nlabels) == int + assert isinstance(colorize_nlabels, int) self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -89,7 +88,7 @@ def init_from_ckpt(self, path, ignore_keys=list()): try: self._cur_epoch = sd["epoch"] sd = sd["state_dict"] - except: + except Exception: self._cur_epoch = "null" keys = list(sd.keys()) for k in keys: diff --git a/videotuna/lvdm/opensoravae.py b/videotuna/lvdm/opensoravae.py index 4ff2a9a8..70c5dd04 100644 --- a/videotuna/lvdm/opensoravae.py +++ b/videotuna/lvdm/opensoravae.py @@ -3,13 +3,9 @@ import pytorch_lightning as pl import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder +from diffusers.models import AutoencoderKL from einops import rearrange -from videotuna.base.distributions import DiagonalGaussianDistribution - class VideoAutoencoderKL(pl.LightningModule): def __init__( diff --git a/videotuna/opensora/models/layers/blocks.py b/videotuna/opensora/models/layers/blocks.py index 7961e4c0..2bfb333c 100644 --- a/videotuna/opensora/models/layers/blocks.py +++ b/videotuna/opensora/models/layers/blocks.py @@ -29,7 +29,9 @@ ) from videotuna.opensora.acceleration.parallel_states import get_sequence_parallel_group -approx_gelu = lambda: nn.GELU(approximate="tanh") + +def approx_gelu(): + return nn.GELU(approximate="tanh") class LlamaRMSNorm(nn.Module): diff --git a/videotuna/opensora/models/stdit/__init__.py b/videotuna/opensora/models/stdit/__init__.py index 7d59db5b..e69de29b 100644 --- a/videotuna/opensora/models/stdit/__init__.py +++ b/videotuna/opensora/models/stdit/__init__.py @@ -1,8 +0,0 @@ -from .stdit import STDiT -from .stdit2 import STDiT2 -from .stdit3 import STDiT3 -from .stdit4 import STDiT4 -from .stdit5 import STDiT5 -from .stdit6 import STDiT6 -from .stdit7 import STDiT7 -from .stdit8 import STDiT8 diff --git a/videotuna/opensora/models/stdit/stdit5.py b/videotuna/opensora/models/stdit/stdit5.py index 30c4b1f6..ed62b955 100644 --- a/videotuna/opensora/models/stdit/stdit5.py +++ b/videotuna/opensora/models/stdit/stdit5.py @@ -8,13 +8,10 @@ from rotary_embedding_torch import RotaryEmbedding from timm.models.layers import DropPath from timm.models.vision_transformer import Mlp -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig from videotuna.opensora.acceleration.checkpoint import auto_grad_checkpoint -from videotuna.opensora.acceleration.communications import ( - gather_forward_split_backward, - split_forward_gather_backward, -) +from videotuna.opensora.acceleration.communications import split_forward_gather_backward from
videotuna.opensora.acceleration.parallel_states import get_sequence_parallel_group from videotuna.opensora.models.layers.blocks import ( Attention, diff --git a/videotuna/opensora/models/stdit/stdit6.py b/videotuna/opensora/models/stdit/stdit6.py index 0ffe452d..80364f98 100644 --- a/videotuna/opensora/models/stdit/stdit6.py +++ b/videotuna/opensora/models/stdit/stdit6.py @@ -7,13 +7,10 @@ from einops import rearrange from timm.models.layers import DropPath from timm.models.vision_transformer import Mlp -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig from videotuna.opensora.acceleration.checkpoint import auto_grad_checkpoint -from videotuna.opensora.acceleration.communications import ( - gather_forward_split_backward, - split_forward_gather_backward, -) +from videotuna.opensora.acceleration.communications import split_forward_gather_backward from videotuna.opensora.acceleration.parallel_states import get_sequence_parallel_group from videotuna.opensora.models.layers.blocks import ( Attention, diff --git a/videotuna/opensora/models/stdit/stdit7.py b/videotuna/opensora/models/stdit/stdit7.py index 9f978526..d551b86c 100644 --- a/videotuna/opensora/models/stdit/stdit7.py +++ b/videotuna/opensora/models/stdit/stdit7.py @@ -5,16 +5,12 @@ import torch.distributed as dist import torch.nn as nn from einops import rearrange -from rotary_embedding_torch import RotaryEmbedding from timm.models.layers import DropPath from timm.models.vision_transformer import Mlp -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig from videotuna.opensora.acceleration.checkpoint import auto_grad_checkpoint -from videotuna.opensora.acceleration.communications import ( - gather_forward_split_backward, - split_forward_gather_backward, -) +from videotuna.opensora.acceleration.communications import split_forward_gather_backward from videotuna.opensora.acceleration.parallel_states import get_sequence_parallel_group from videotuna.opensora.models.layers.blocks import ( Attention, diff --git a/videotuna/opensora/models/stdit/stdit8.py b/videotuna/opensora/models/stdit/stdit8.py index a32a9dee..36a66a1b 100644 --- a/videotuna/opensora/models/stdit/stdit8.py +++ b/videotuna/opensora/models/stdit/stdit8.py @@ -29,8 +29,6 @@ T2IFinalLayer, TimestepEmbedder, approx_gelu, - get_1d_sincos_pos_embed, - get_2d_sincos_pos_embed, get_layernorm, t2i_modulate, ) diff --git a/videotuna/opensora/models/stdit/stdit8_debug.py b/videotuna/opensora/models/stdit/stdit8_debug.py index 62c5727e..5c28e19c 100644 --- a/videotuna/opensora/models/stdit/stdit8_debug.py +++ b/videotuna/opensora/models/stdit/stdit8_debug.py @@ -32,12 +32,9 @@ T2IFinalLayer, TimestepEmbedder, approx_gelu, - get_1d_sincos_pos_embed, - get_2d_sincos_pos_embed, get_layernorm, t2i_modulate, ) -from videotuna.opensora.registry import MODELS from videotuna.opensora.utils.ckpt_utils import load_checkpoint diff --git a/videotuna/opensora/models/text_encoder/__init__.py b/videotuna/opensora/models/text_encoder/__init__.py index 9fc6a999..e69de29b 100644 --- a/videotuna/opensora/models/text_encoder/__init__.py +++ b/videotuna/opensora/models/text_encoder/__init__.py @@ -1,3 +0,0 @@ -from .classes import ClassEncoder -from .clip import ClipEncoder -from .t5 import T5Encoder diff --git a/videotuna/opensora/models/text_encoder/t5.py b/videotuna/opensora/models/text_encoder/t5.py index 9544a560..5911e31a 100644 --- a/videotuna/opensora/models/text_encoder/t5.py +++ 
b/videotuna/opensora/models/text_encoder/t5.py @@ -23,14 +23,12 @@ import html -import os import re import urllib.parse as ul import ftfy import torch from bs4 import BeautifulSoup -from huggingface_hub import hf_hub_download from transformers import AutoTokenizer, T5EncoderModel from videotuna.opensora.registry import MODELS diff --git a/videotuna/opensora/models/vae/__init__.py b/videotuna/opensora/models/vae/__init__.py index f27be475..e69de29b 100644 --- a/videotuna/opensora/models/vae/__init__.py +++ b/videotuna/opensora/models/vae/__init__.py @@ -1,3 +0,0 @@ -from .discriminator import DISCRIMINATOR_3D -from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder -from .vae_temporal import VAE_Temporal diff --git a/videotuna/opensora/models/vae/discriminator.py b/videotuna/opensora/models/vae/discriminator.py index ef10c8eb..3d60a24e 100644 --- a/videotuna/opensora/models/vae/discriminator.py +++ b/videotuna/opensora/models/vae/discriminator.py @@ -173,8 +173,8 @@ def __init__( norm_layer = nn.BatchNorm2d - if ( - type(norm_layer) == functools.partial + if isinstance( + norm_layer, functools.partial ): # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: @@ -250,7 +250,7 @@ def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): norm_layer = nn.BatchNorm3d else: raise NotImplementedError("Not implemented.") - if type(norm_layer) == functools.partial: + if isinstance(norm_layer, functools.partial): use_bias = norm_layer.func != nn.BatchNorm3d else: use_bias = norm_layer != nn.BatchNorm3d diff --git a/videotuna/opensora/models/vae/losses.py b/videotuna/opensora/models/vae/losses.py index 1b61c7b3..9685b91f 100644 --- a/videotuna/opensora/models/vae/losses.py +++ b/videotuna/opensora/models/vae/losses.py @@ -63,7 +63,7 @@ def __init__( ): super().__init__() - if type(dtype) == str: + if isinstance(dtype, str): if dtype == "bf16": dtype = torch.bfloat16 elif dtype == "fp16": diff --git a/videotuna/opensora/models/vae/lpips.py b/videotuna/opensora/models/vae/lpips.py index dd28c326..1748869c 100644 --- a/videotuna/opensora/models/vae/lpips.py +++ b/videotuna/opensora/models/vae/lpips.py @@ -94,8 +94,8 @@ def forward(self, input, target): for kk in range(len(self.chns)) ] val = res[0] - for l in range(1, len(self.chns)): - val += res[l] + for index in range(1, len(self.chns)): + val += res[index] return val diff --git a/videotuna/opensora/utils/ckpt_utils.py b/videotuna/opensora/utils/ckpt_utils.py index 5a9ed8d9..cb3085be 100644 --- a/videotuna/opensora/utils/ckpt_utils.py +++ b/videotuna/opensora/utils/ckpt_utils.py @@ -5,7 +5,6 @@ import os from typing import Tuple -import colossalai import torch import torch.distributed as dist import torch.nn as nn @@ -270,7 +269,7 @@ def save_frequently( save_dir: str, shape_dict: dict, ): - save_dir = os.path.join(save_dir, f"last") + save_dir = os.path.join(save_dir, "last") os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) booster.save_model(model, os.path.join(save_dir, "model"), shard=True) diff --git a/videotuna/third_party/flux/caching/text_embeds.py b/videotuna/third_party/flux/caching/text_embeds.py index 171b8feb..c43975d4 100644 --- a/videotuna/third_party/flux/caching/text_embeds.py +++ b/videotuna/third_party/flux/caching/text_embeds.py @@ -926,7 +926,7 @@ def compute_embeddings_for_legacy_prompts( # We attempt to load. 
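Most of the mechanical changes in this patch swap `type(x) == T` comparisons for `isinstance(x, T)`. A short sketch of the norm-layer check from discriminator.py as it reads after the change (assuming torch is installed); unlike the exact-type comparison, `isinstance` also accepts subclasses, which is almost always the intended behaviour.

import functools

import torch.nn as nn

norm_layer = functools.partial(nn.BatchNorm2d, affine=True)

# BatchNorm layers carry their own affine bias, so the conv bias can be skipped.
if isinstance(norm_layer, functools.partial):
    use_bias = norm_layer.func != nn.BatchNorm2d
else:
    use_bias = norm_layer != nn.BatchNorm2d

print(use_bias)  # False: the partial wraps BatchNorm2d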
logging.debug("Loading embed from cache.") prompt_embeds = self.load_from_cache(filename) - if type(prompt_embeds) is tuple and len(prompt_embeds) == 2: + if isinstance(prompt_embeds, tuple) and len(prompt_embeds) == 2: # we have an attention mask stored with the embed. prompt_embeds, attention_mask = prompt_embeds logging.debug(f"Loaded embeds: {prompt_embeds.shape}") diff --git a/videotuna/third_party/flux/caching/vae.py b/videotuna/third_party/flux/caching/vae.py index 1d26b0fd..5ecafa6e 100644 --- a/videotuna/third_party/flux/caching/vae.py +++ b/videotuna/third_party/flux/caching/vae.py @@ -870,7 +870,7 @@ def read_images_in_batch(self) -> None: " These samples likely do not exist in the storage pool any longer." ) for filepath, element in zip(available_filepaths, batch_output): - if type(filepath) != str: + if not isinstance(filepath, str): raise ValueError( f"Received unknown filepath type ({type(filepath)}) value: {filepath}" ) @@ -879,11 +879,11 @@ def read_images_in_batch(self) -> None: self.process_queue.put((filepath, element, aspect_bucket)) def _process_raw_filepath(self, raw_filepath: str): - if type(raw_filepath) == str or len(raw_filepath) == 1: + if isinstance(raw_filepath, str) or len(raw_filepath) == 1: filepath = raw_filepath elif len(raw_filepath) == 2: basename, filepath = raw_filepath - elif type(raw_filepath) == Path or type(raw_filepath) == numpy_str: + elif isinstance(raw_filepath, Path) or isinstance(raw_filepath, numpy_str): filepath = str(raw_filepath) else: raise ValueError( diff --git a/videotuna/third_party/flux/configuration/cmd_args.py b/videotuna/third_party/flux/configuration/cmd_args.py index 7510830b..d658f36c 100644 --- a/videotuna/third_party/flux/configuration/cmd_args.py +++ b/videotuna/third_party/flux/configuration/cmd_args.py @@ -5,7 +5,6 @@ import sys import time from datetime import timedelta -from typing import Dict, List, Optional, Tuple import torch from accelerate import InitProcessGroupKwargs @@ -16,7 +15,6 @@ from videotuna.third_party.flux.training.optimizer_param import ( is_optimizer_deprecated, is_optimizer_grad_fp32, - map_deprecated_optimizer_parameter, optimizer_choices, ) @@ -1943,7 +1941,7 @@ def parse_cmdline_args(input_args=None): print_on_main_thread(f"{key_val}") try: args = parser.parse_args(input_args) - except: + except Exception: logger.error(f"Could not parse input: {input_args}") import traceback @@ -2216,7 +2214,7 @@ def parse_cmdline_args(input_args=None): f"-!- Flux supports a max length of {model_max_seq_length} tokens, but you have supplied `--i_know_what_i_am_doing`, so this limit will not be enforced. -!-" ) warning_log( - f"The model will begin to collapse after a short period of time, if the model you are continuing from has not been tuned beyond 256 tokens." + "The model will begin to collapse after a short period of time, if the model you are continuing from has not been tuned beyond 256 tokens." 
) if flux_version == "dev": if args.validation_num_inference_steps > 28: @@ -2322,7 +2320,6 @@ def parse_cmdline_args(input_args=None): args.disable_accelerator = os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False) if "lycoris" == args.lora_type.lower(): - from lycoris import create_lycoris if args.lycoris_config is None: raise ValueError( diff --git a/videotuna/third_party/flux/configuration/configure.py b/videotuna/third_party/flux/configuration/configure.py index 09640e27..168012b1 100644 --- a/videotuna/third_party/flux/configuration/configure.py +++ b/videotuna/third_party/flux/configuration/configure.py @@ -108,7 +108,7 @@ def configure_lycoris(): # Prompt user to select an algorithm algo = prompt_user( - f"Which LyCORIS algorithm would you like to use? (Enter the number corresponding to the algorithm)", + "Which LyCORIS algorithm would you like to use? (Enter the number corresponding to the algorithm)", "3", # Default to LoKr ) @@ -339,7 +339,7 @@ def configure_env(): whoami = None try: whoami = huggingface_hub.whoami() - except: + except Exception: pass should_retry = True while not whoami and should_retry: @@ -484,7 +484,7 @@ def configure_env(): model_info = huggingface_hub.model_info(model_name) if hasattr(model_info, "id"): can_load_model = True - except: + except Exception: continue env_contents["--model_type"] = model_type env_contents["--pretrained_model_name_or_path"] = model_name diff --git a/videotuna/third_party/flux/configuration/env_file.py b/videotuna/third_party/flux/configuration/env_file.py index 19b1d057..8e5f6d45 100644 --- a/videotuna/third_party/flux/configuration/env_file.py +++ b/videotuna/third_party/flux/configuration/env_file.py @@ -1,4 +1,6 @@ import json +import logging +import os env_to_args_map = { "RESUME_CHECKPOINT": "--resume_from_checkpoint", @@ -68,9 +70,6 @@ "DISABLE_BENCHMARK": "--disable_benchmark", } -import logging -import os -import subprocess logger = logging.getLogger("SimpleTuner") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) diff --git a/videotuna/third_party/flux/convert_parquet_to_images.py b/videotuna/third_party/flux/convert_parquet_to_images.py index 5ce2d0a1..dc598e5e 100644 --- a/videotuna/third_party/flux/convert_parquet_to_images.py +++ b/videotuna/third_party/flux/convert_parquet_to_images.py @@ -1,14 +1,12 @@ import io import os -import numpy as np import pandas as pd from PIL import Image # Step 1: Load Parquet File parquet_file_path = "data/train-00000-of-00001-dfb0d9df7ebab67e.parquet" # Replace with your Parquet file path output_directory = "data-res" # Directory to save the images -import pandas as pd # Load the Parquet file into a DataFrame df = pd.read_parquet(parquet_file_path) diff --git a/videotuna/third_party/flux/data_backend/aws.py b/videotuna/third_party/flux/data_backend/aws.py index b8b2a357..f28208b2 100644 --- a/videotuna/third_party/flux/data_backend/aws.py +++ b/videotuna/third_party/flux/data_backend/aws.py @@ -1,5 +1,4 @@ import concurrent.futures -import fnmatch import logging import os import time @@ -14,7 +13,6 @@ from videotuna.third_party.flux.data_backend.base import BaseDataBackend from videotuna.third_party.flux.image_manipulation.load import load_image -from videotuna.third_party.flux.training.multi_process import _get_rank as get_rank loggers_to_silence = [ "botocore.hooks", @@ -112,7 +110,7 @@ def exists(self, s3_key): else: # Sleep for a bit before retrying. 
time.sleep(self.read_retry_interval) - except: + except Exception: if i == self.read_retry_limit - 1: # We have reached our maximum retry count. raise @@ -143,7 +141,7 @@ def read(self, s3_key): else: # Sleep for a bit before retrying. time.sleep(self.read_retry_interval) - except: + except Exception: if i == self.read_retry_limit - 1: # We have reached our maximum retry count. raise @@ -160,7 +158,7 @@ def write(self, s3_key, data): real_key = str(s3_key) for i in range(self.write_retry_limit): try: - if type(data) == Tensor: + if isinstance(data, Tensor): return self.torch_save(data, real_key) response = self.client.put_object( Body=data, @@ -319,7 +317,7 @@ def _detect_file_format(self, fileobj): return "incorrect" else: return "correct_uncompressed" - except Exception as e: + except Exception: # If torch.load fails, it's possibly compressed correctly return "correct_compressed" elif magic_number[:2] == b"\x1f\x8b": diff --git a/videotuna/third_party/flux/data_backend/csv_url_list.py b/videotuna/third_party/flux/data_backend/csv_url_list.py index 790b076f..20c86586 100644 --- a/videotuna/third_party/flux/data_backend/csv_url_list.py +++ b/videotuna/third_party/flux/data_backend/csv_url_list.py @@ -4,7 +4,7 @@ import os from io import BytesIO from pathlib import Path -from typing import Any, BinaryIO, Optional, Union +from typing import Any, Optional, Union import pandas as pd import requests @@ -235,7 +235,7 @@ def read_image_batch( self, filepaths: list, delete_problematic_images: bool = False ) -> list: """Read a batch of images from the specified filepaths.""" - if type(filepaths) != list: + if not isinstance(filepaths, list): raise ValueError( f"read_image_batch must be given a list of image filepaths. we received: {filepaths}" ) diff --git a/videotuna/third_party/flux/data_backend/factory.py b/videotuna/third_party/flux/data_backend/factory.py index 15fe437d..b1e6bd98 100644 --- a/videotuna/third_party/flux/data_backend/factory.py +++ b/videotuna/third_party/flux/data_backend/factory.py @@ -5,7 +5,6 @@ import queue import threading import time -from math import sqrt import torch from tqdm import tqdm @@ -20,10 +19,7 @@ from videotuna.third_party.flux.multiaspect.sampler import MultiAspectSampler from videotuna.third_party.flux.prompts import PromptHandler from videotuna.third_party.flux.training.collate import collate_fn -from videotuna.third_party.flux.training.default_settings import ( - default, - latest_config_version, -) +from videotuna.third_party.flux.training.default_settings import latest_config_version from videotuna.third_party.flux.training.exceptions import MultiDatasetExhausted from videotuna.third_party.flux.training.multi_process import _get_rank as get_rank from videotuna.third_party.flux.training.multi_process import rank_info, should_log diff --git a/videotuna/third_party/flux/data_backend/local.py b/videotuna/third_party/flux/data_backend/local.py index b1baaef5..cebb083c 100644 --- a/videotuna/third_party/flux/data_backend/local.py +++ b/videotuna/third_party/flux/data_backend/local.py @@ -5,7 +5,6 @@ from typing import Any import torch -from regex import regex from videotuna.third_party.flux.data_backend.base import BaseDataBackend from videotuna.third_party.flux.image_manipulation.load import load_image @@ -151,7 +150,7 @@ def read_image_batch( self, filepaths: list, delete_problematic_images: bool = False ) -> list: """Read a batch of images from the specified filepaths.""" - if type(filepaths) != list: + if not isinstance(filepaths, list): raise ValueError( 
f"read_image_batch must be given a list of image filepaths. we received: {filepaths}" ) @@ -195,7 +194,7 @@ def torch_load(self, filename): if self.compress_cache: try: stored_tensor = self._decompress_torch(stored_tensor) - except Exception as e: + except Exception: pass if hasattr(stored_tensor, "seek"): diff --git a/videotuna/third_party/flux/log_format.py b/videotuna/third_party/flux/log_format.py index b15eddbf..c3c5264c 100644 --- a/videotuna/third_party/flux/log_format.py +++ b/videotuna/third_party/flux/log_format.py @@ -1,5 +1,6 @@ import logging import os +import warnings from colorama import Back, Fore, Style, init @@ -76,7 +77,6 @@ def format(self, record): torch_utils_logger = logging.getLogger("diffusers.utils.torch_utils") torch_utils_logger.setLevel("ERROR") -import warnings # Suppress specific PIL warning warnings.filterwarnings( diff --git a/videotuna/third_party/flux/metadata/backends/base.py b/videotuna/third_party/flux/metadata/backends/base.py index 0231c320..c24b4556 100644 --- a/videotuna/third_party/flux/metadata/backends/base.py +++ b/videotuna/third_party/flux/metadata/backends/base.py @@ -779,7 +779,7 @@ def scan_for_metadata(self): logger.debug( f"Received type of metadata update: {type(metadata_update)}, contents: {metadata_update}" ) - if type(metadata_update) == dict: + if isinstance(metadata_update, dict): for filepath, meta in metadata_update.items(): self.set_metadata_by_filepath( filepath=filepath, metadata=meta, update_json=False diff --git a/videotuna/third_party/flux/metadata/backends/parquet.py b/videotuna/third_party/flux/metadata/backends/parquet.py index 6a64e61a..89277c61 100644 --- a/videotuna/third_party/flux/metadata/backends/parquet.py +++ b/videotuna/third_party/flux/metadata/backends/parquet.py @@ -147,7 +147,7 @@ def _extract_captions_to_fast_list(self): if not identifier_includes_extension: filename = os.path.splitext(filename)[0] - if type(caption_column) == list: + if isinstance(caption_column, list): caption = None if len(caption_column) > 0: caption = [row[c] for c in caption_column] @@ -160,9 +160,9 @@ def _extract_captions_to_fast_list(self): raise ValueError( f"Could not locate caption for image {filename} in sampler_backend {self.id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(self.parquet_database)} entries." 
) - if type(caption) == bytes: + if isinstance(caption, bytes): caption = caption.decode("utf-8") - elif type(caption) == list: + elif isinstance(caption, list): caption = [c.strip() for c in caption if c.strip()] if caption: caption = caption.strip() @@ -387,7 +387,6 @@ def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = Fals # Now, pull metadata updates from the queue if len(metadata_updates) > 0 and file in metadata_updates: - metadata_update = metadata_updates[file] self.set_metadata_by_filepath( filepath=file, metadata=metadata_updates[file], update_json=False ) @@ -422,7 +421,7 @@ def _get_first_value(self, series_or_scalar): return series_or_scalar elif isinstance(series_or_scalar, numpy.int64): new_type = int(series_or_scalar) - if type(new_type) != int: + if not isinstance(new_type, int): raise ValueError(f"Unsupported data type: {type(series_or_scalar)}.") return new_type else: @@ -516,7 +515,7 @@ def _process_for_bucket( ) try: self.data_backend.delete(image_path_str) - except: + except Exception: pass statistics.setdefault("skipped", {}).setdefault("too_small", 0) statistics["skipped"]["too_small"] += 1 diff --git a/videotuna/third_party/flux/models/flux/__init__.py b/videotuna/third_party/flux/models/flux/__init__.py index fec70e30..ee3fd1e0 100644 --- a/videotuna/third_party/flux/models/flux/__init__.py +++ b/videotuna/third_party/flux/models/flux/__init__.py @@ -6,7 +6,6 @@ calculate_shift as calculate_shift_flux, ) -from videotuna.third_party.flux.models.flux.pipeline import FluxPipeline from videotuna.third_party.flux.training import steps_remaining_in_epoch diff --git a/videotuna/third_party/flux/models/flux/attention.py b/videotuna/third_party/flux/models/flux/attention.py index 6f69855f..e545c40b 100644 --- a/videotuna/third_party/flux/models/flux/attention.py +++ b/videotuna/third_party/flux/models/flux/attention.py @@ -1,4 +1,4 @@ -from diffusers.models.attention_processor import Attention +import torch from diffusers.models.embeddings import apply_rotary_emb from einops import rearrange from torch import FloatTensor, Tensor @@ -6,7 +6,7 @@ try: from flash_attn_interface import flash_attn_func -except: +except Exception: pass diff --git a/videotuna/third_party/flux/models/flux/pipeline.py b/videotuna/third_party/flux/models/flux/pipeline.py index 3f2fdcc7..8a962cfc 100644 --- a/videotuna/third_party/flux/models/flux/pipeline.py +++ b/videotuna/third_party/flux/models/flux/pipeline.py @@ -13,9 +13,11 @@ # limitations under the License. 
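attention.py above keeps the optional `flash_attn_interface` import behind a try/except, now catching `Exception` instead of everything. A hedged sketch of that availability-flag pattern with a portable SDPA fallback (the fallback function is an illustration, not this repository's attention processor):

import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # noqa: F401  optional FA3 kernel
    IS_FLASH_ATTN_AVAILABLE = True
except Exception:  # ImportError, or a broken CUDA build of the extension
    IS_FLASH_ATTN_AVAILABLE = False

def sdpa_attention(q, k, v):
    # Portable fallback used whenever the optional kernel is unavailable.
    return F.scaled_dot_product_attention(q, k, v)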
import inspect +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np +import PIL.Image import torch from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxLoraLoaderMixin @@ -25,6 +27,7 @@ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, + BaseOutput, is_torch_xla_available, logging, replace_example_docstring, @@ -915,13 +918,6 @@ def __call__( return FluxPipelineOutput(images=image) -from dataclasses import dataclass -from typing import List, Union - -import PIL.Image -from diffusers.utils import BaseOutput - - @dataclass class FluxPipelineOutput(BaseOutput): """ diff --git a/videotuna/third_party/flux/models/flux/transformer.py b/videotuna/third_party/flux/models/flux/transformer.py index c677fa7d..b0ff6d69 100644 --- a/videotuna/third_party/flux/models/flux/transformer.py +++ b/videotuna/third_party/flux/models/flux/transformer.py @@ -3,7 +3,7 @@ # Originally licensed under the Apache License, Version 2.0 (the "License"); # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -33,21 +33,20 @@ ) from diffusers.utils.torch_utils import maybe_allow_in_graph +from videotuna.third_party.flux.models.flux.attention import ( + FluxAttnProcessor3_0, + FluxSingleAttnProcessor3_0, +) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name is_flash_attn_available = False try: - from flash_attn_interface import flash_attn_func is_flash_attn_available = True -except: +except Exception: pass -from videotuna.third_party.flux.models.flux.attention import ( - FluxAttnProcessor3_0, - FluxSingleAttnProcessor3_0, -) - class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/videotuna/third_party/flux/models/sd3/pipeline.py b/videotuna/third_party/flux/models/sd3/pipeline.py index 90753b48..50a856e9 100644 --- a/videotuna/third_party/flux/models/sd3/pipeline.py +++ b/videotuna/third_party/flux/models/sd3/pipeline.py @@ -15,8 +15,9 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import PIL.Image import torch -from diffusers.image_processor import VaeImageProcessor +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import SD3Transformer2DModel @@ -1012,17 +1013,6 @@ def __call__( # See the License for the specific language governing permissions and # limitations under the License. 
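pipeline.py above moves the imports backing `FluxPipelineOutput` (dataclass, PIL.Image, BaseOutput) to the top of the module instead of declaring them after `__call__`. For reference, the output container follows the usual diffusers BaseOutput convention, roughly:

from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image
from diffusers.utils import BaseOutput

@dataclass
class FluxPipelineOutput(BaseOutput):
    # BaseOutput behaves like an ordered dict with attribute access,
    # so callers can use result.images or result["images"].
    images: Union[List[PIL.Image.Image], np.ndarray]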
-from typing import Callable, Dict, List, Optional, Union - -import PIL.Image -import torch -from diffusers.image_processor import PipelineImageInput -from transformers import ( - CLIPTextModelWithProjection, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) if is_torch_xla_available(): import torch_xla.core.xla_model as xm diff --git a/videotuna/third_party/flux/models/smoldit/__init__.py b/videotuna/third_party/flux/models/smoldit/__init__.py index c5855473..7fb1e5b8 100644 --- a/videotuna/third_party/flux/models/smoldit/__init__.py +++ b/videotuna/third_party/flux/models/smoldit/__init__.py @@ -1,6 +1,3 @@ -from videotuna.third_party.flux.models.smoldit.pipeline import SmolDiTPipeline -from videotuna.third_party.flux.models.smoldit.transformer import SmolDiT2DModel - SmolDiTConfigurations = { "smoldit-small": { "sample_size": 64, diff --git a/videotuna/third_party/flux/multiaspect/image.py b/videotuna/third_party/flux/multiaspect/image.py index dcef576b..7da15adb 100644 --- a/videotuna/third_party/flux/multiaspect/image.py +++ b/videotuna/third_party/flux/multiaspect/image.py @@ -58,10 +58,10 @@ def is_image_too_large(image_size: tuple, resolution: float, resolution_type: st def calculate_new_size_by_pixel_edge( aspect_ratio: float, resolution: int, original_size: tuple ): - if type(aspect_ratio) != float: + if not isinstance(aspect_ratio, float): raise ValueError(f"Aspect ratio must be a float, not {type(aspect_ratio)}") - if type(resolution) != int and ( - type(resolution) != float or int(resolution) != resolution + if not isinstance(resolution, int) and ( + not isinstance(resolution, float) or int(resolution) != resolution ): raise ValueError(f"Resolution must be an int, not {type(resolution)}") diff --git a/videotuna/third_party/flux/multiaspect/sampler.py b/videotuna/third_party/flux/multiaspect/sampler.py index dcf857c5..66d25c67 100644 --- a/videotuna/third_party/flux/multiaspect/sampler.py +++ b/videotuna/third_party/flux/multiaspect/sampler.py @@ -172,7 +172,7 @@ def retrieve_validation_set(self, batch_size: int): prepend_instance_prompt=self.prepend_instance_prompt, instance_prompt=self.instance_prompt, ) - if type(validation_prompt) == list: + if isinstance(validation_prompt, list): validation_prompt = random.choice(validation_prompt) self.debug_log( f"Selecting random prompt from list: {validation_prompt}" @@ -442,7 +442,7 @@ def _validate_and_yield_images_from_samples(self, samples, bucket): prepend_instance_prompt=self.prepend_instance_prompt, instance_prompt=self.instance_prompt, ) - if type(instance_prompt) == list: + if isinstance(instance_prompt, list): instance_prompt = random.choice(instance_prompt) self.debug_log(f"Selecting random prompt from list: {instance_prompt}") image_metadata["instance_prompt_text"] = instance_prompt diff --git a/videotuna/third_party/flux/prompts.py b/videotuna/third_party/flux/prompts.py index 054a69c4..72595926 100644 --- a/videotuna/third_party/flux/prompts.py +++ b/videotuna/third_party/flux/prompts.py @@ -1,8 +1,12 @@ import json +import logging +import os from pathlib import Path import regex as re +from tqdm import tqdm +from videotuna.third_party.flux.data_backend.base import BaseDataBackend from videotuna.third_party.flux.training import image_file_extensions from videotuna.third_party.flux.training.multi_process import _get_rank as get_rank from videotuna.third_party.flux.training.state_tracker import StateTracker @@ -85,13 +89,6 @@ def prompt_library_injection(new_prompts: dict) -> dict: return {**prompts, **new_prompts} 
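The parquet.py and prompts.py hunks around here apply the same caption normalisation in several places: decode bytes, strip strings, and optionally prepend the instance prompt to every list entry. A combined sketch of that logic (the helper name is hypothetical):

def normalize_caption(caption, instance_prompt=None):
    # bytes from the parquet/text backend -> str
    if isinstance(caption, bytes):
        caption = caption.decode("utf-8")
    # multi-caption columns arrive as lists; drop empty entries
    if isinstance(caption, list):
        captions = [c.strip() for c in caption if c.strip()]
        if instance_prompt:
            captions = [f"{instance_prompt} {c}" for c in captions]
        return captions
    caption = caption.strip()
    return f"{instance_prompt} {caption}" if instance_prompt else caption

print(normalize_caption(b"a photo of a cat "))  # 'a photo of a cat'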
-import logging -import os - -from tqdm import tqdm - -from videotuna.third_party.flux.data_backend.base import BaseDataBackend - logger = logging.getLogger("PromptHandler") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) @@ -258,12 +255,12 @@ def prepare_instance_prompt_from_parquet( raise ValueError( f"Could not locate caption for image {image_path} in sampler_backend {sampler_backend_id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(parquet_db)} entries." ) - if type(image_caption) == bytes: + if isinstance(image_caption, bytes): image_caption = image_caption.decode("utf-8") if image_caption: image_caption = image_caption.strip() if prepend_instance_prompt: - if type(image_caption) == list: + if isinstance(image_caption, list): image_caption = [instance_prompt + " " + x for x in image_caption] else: image_caption = instance_prompt + " " + image_caption @@ -309,7 +306,7 @@ def prepare_instance_prompt_from_textfile( try: image_caption = data_backend.read(caption_file) # Convert from bytes to str: - if type(image_caption) == bytes: + if isinstance(image_caption, bytes): image_caption = image_caption.decode("utf-8") # any newlines? split into array @@ -407,7 +404,7 @@ def get_all_captions( backend_config = StateTracker.get_data_backend_config( data_backend_id=data_backend.id ) - if type(all_image_files) == list and type(all_image_files[0]) == tuple: + if isinstance(all_image_files, list) and isinstance(all_image_files[0], tuple): all_image_files = all_image_files[0][2] from tqdm import tqdm @@ -449,7 +446,7 @@ def get_all_captions( data_backend=data_backend, sampler_backend_id=data_backend.id, ) - except: + except Exception: continue elif caption_strategy == "instanceprompt": return [instance_prompt] @@ -506,7 +503,7 @@ def filter_captions(data_backend: BaseDataBackend, captions: list) -> list: if not caption_filter_list or caption_filter_list == "": return captions if ( - type(caption_filter_list) == str + isinstance(caption_filter_list, str) and os.path.splitext(caption_filter_list)[1] == ".json" ): # It's a path to a filter list. Load it in JSON format. @@ -520,7 +517,7 @@ def filter_captions(data_backend: BaseDataBackend, captions: list) -> list: ) raise e elif ( - type(caption_filter_list) == str + isinstance(caption_filter_list, str) and os.path.splitext(caption_filter_list)[1] == ".txt" ): # We have a plain text list of filter strings/regex. Load them into an array: @@ -536,7 +533,7 @@ def filter_captions(data_backend: BaseDataBackend, captions: list) -> list: ) raise e # We have the filter list. Is it valid and non-empty? - if type(caption_filter_list) != list or len(caption_filter_list) == 0: + if not isinstance(caption_filter_list, list) or len(caption_filter_list) == 0: logger.debug( f"Data backend '{data_backend.id}' has an invalid or empty caption filter list." 
) @@ -585,7 +582,7 @@ def filter_captions(data_backend: BaseDataBackend, captions: list) -> list: pattern = re.compile(filter_item) try: regex_modified_caption = pattern.sub("", modified_caption) - except: + except Exception: regex_modified_caption = modified_caption if regex_modified_caption != modified_caption: # logger.debug( diff --git a/videotuna/third_party/flux/publishing/huggingface.py b/videotuna/third_party/flux/publishing/huggingface.py index 8c49b145..6c7eb5d2 100644 --- a/videotuna/third_party/flux/publishing/huggingface.py +++ b/videotuna/third_party/flux/publishing/huggingface.py @@ -100,7 +100,7 @@ def upload_model(self, validation_images, webhook_handler=None, override_path=No self.upload_validation_folder( webhook_handler=webhook_handler, override_path=override_path ) - except: + except Exception: logger.error("Error uploading validation images to Hugging Face Hub.") attempt = 0 diff --git a/videotuna/third_party/flux/publishing/metadata.py b/videotuna/third_party/flux/publishing/metadata.py index b966eda6..10dab105 100644 --- a/videotuna/third_party/flux/publishing/metadata.py +++ b/videotuna/third_party/flux/publishing/metadata.py @@ -131,7 +131,7 @@ def _model_load(args, repo_id: str = None): f"\npipeline = DiffusionPipeline.from_pretrained(model_id)" ) if args.model_type == "lora" and args.lora_type.lower() == "lycoris": - output += f"\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipeline.transformer)" + output += "\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipeline.transformer)" output += "\nwrapper.merge_to()" return output @@ -157,7 +157,7 @@ def _guidance_rescale(args): def _validation_resolution(args): if args.validation_resolution == "" or args.validation_resolution is None: - return f"width=1024,\n" f" height=1024," + return "width=1024,\n" " height=1024," resolutions = [args.validation_resolution] if "," in args.validation_resolution: # split the resolution into a list of resolutions diff --git a/videotuna/third_party/flux/training/__init__.py b/videotuna/third_party/flux/training/__init__.py index e0a8b83f..d212063b 100644 --- a/videotuna/third_party/flux/training/__init__.py +++ b/videotuna/third_party/flux/training/__init__.py @@ -1,3 +1,5 @@ +import torch + quantised_precision_levels = [ "no_change", "int8-quanto", @@ -5,7 +7,6 @@ "int2-quanto", "int8-torchao", ] -import torch if torch.cuda.is_available(): quantised_precision_levels.extend( diff --git a/videotuna/third_party/flux/training/collate.py b/videotuna/third_party/flux/training/collate.py index 26943eaa..8fef5ac2 100644 --- a/videotuna/third_party/flux/training/collate.py +++ b/videotuna/third_party/flux/training/collate.py @@ -5,6 +5,7 @@ import numpy as np import torch +from torchvision.transforms import ToTensor from videotuna.third_party.flux.image_manipulation.training_sample import TrainingSample from videotuna.third_party.flux.training.multi_process import rank_info @@ -13,7 +14,6 @@ logger = logging.getLogger("collate_fn") logger.setLevel(environ.get("SIMPLETUNER_COLLATE_LOG_LEVEL", "INFO")) rank_text = rank_info() -from torchvision.transforms import ToTensor # Convert PIL Image to PyTorch Tensor to_tensor = ToTensor() @@ -252,7 +252,7 @@ def compute_single_embedding( prompt_embeds = text_embed_cache.compute_embeddings_for_legacy_prompts( [caption] ) - if type(prompt_embeds) == tuple: + if isinstance(prompt_embeds, tuple): if StateTracker.get_model_family() in ["pixart_sigma", "smoldit"]: # PixArt requires the attn mask be returned, too. 
prompt_embeds, attn_mask = prompt_embeds diff --git a/videotuna/third_party/flux/training/custom_schedule.py b/videotuna/third_party/flux/training/custom_schedule.py index 215ffd61..2727d44f 100644 --- a/videotuna/third_party/flux/training/custom_schedule.py +++ b/videotuna/third_party/flux/training/custom_schedule.py @@ -1,9 +1,15 @@ import logging import math import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union import accelerate import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.optimization import get_scheduler +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput from torch.optim.lr_scheduler import LambdaLR, LRScheduler from videotuna.third_party.flux.training.state_tracker import StateTracker @@ -486,9 +492,6 @@ def print_lr(self, is_verbose, group, lr, epoch=None): ) -from diffusers.optimization import get_scheduler - - def get_lr_scheduler( args, optimizer, accelerator, logger, use_deepspeed_scheduler=False ): @@ -586,14 +589,6 @@ def get_lr_scheduler( # DISCLAIMER: This code is strongly influenced by https://github.com/leffff/euler-scheduler -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.utils import BaseOutput - @dataclass class FlowMatchingEulerSchedulerOutput(BaseOutput): diff --git a/videotuna/third_party/flux/training/default_settings/safety_check.py b/videotuna/third_party/flux/training/default_settings/safety_check.py index c86ce69c..ec4540a4 100644 --- a/videotuna/third_party/flux/training/default_settings/safety_check.py +++ b/videotuna/third_party/flux/training/default_settings/safety_check.py @@ -5,17 +5,16 @@ from diffusers.utils import is_wandb_available +from videotuna.third_party.flux.training.error_handling import ( + validate_deepspeed_compat_from_args, +) from videotuna.third_party.flux.training.multi_process import _get_rank as get_rank -from videotuna.third_party.flux.training.state_tracker import StateTracker logger = logging.getLogger(__name__) if get_rank() == 0: logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) else: logger.setLevel(logging.ERROR) -from videotuna.third_party.flux.training.error_handling import ( - validate_deepspeed_compat_from_args, -) def safety_check(args, accelerator): @@ -48,7 +47,6 @@ def safety_check(args, accelerator): raise ImportError( "Make sure to install wandb if you want to use it for logging during training." ) - import wandb if accelerator is not None and ( hasattr(accelerator.state, "deepspeed_plugin") and accelerator.state.deepspeed_plugin is not None @@ -120,6 +118,6 @@ def safety_check(args, accelerator): and args.flux_schedule_auto_shift ): logger.error( - f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." + "--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." 
) sys.exit(1) diff --git a/videotuna/third_party/flux/training/diffusion_model.py b/videotuna/third_party/flux/training/diffusion_model.py index d928aa28..ecbaf3ff 100644 --- a/videotuna/third_party/flux/training/diffusion_model.py +++ b/videotuna/third_party/flux/training/diffusion_model.py @@ -66,7 +66,6 @@ def load_diffusion_model(args, weight_dtype): if primary_device.major >= 9: try: import diffusers - from flash_attn_interface import flash_attn_func from videotuna.third_party.flux.models.flux.attention import ( FluxAttnProcessor3_0, @@ -81,7 +80,7 @@ def load_diffusion_model(args, weight_dtype): ) if rank == 0: print("Using FlashAttention3_0 for H100 GPU (Single block)") - except: + except Exception: if rank == 0: logger.warning( "No flash_attn is available, using slower FlashAttention_2_0. Install flash_attn to make use of FA3 for Hopper or newer arch." diff --git a/videotuna/third_party/flux/training/model.py b/videotuna/third_party/flux/training/model.py index 430b733b..5f3e0a88 100644 --- a/videotuna/third_party/flux/training/model.py +++ b/videotuna/third_party/flux/training/model.py @@ -1,5 +1,4 @@ import copy -import glob import hashlib import json import logging @@ -9,28 +8,30 @@ import shutil import sys +import accelerate +import diffusers import huggingface_hub import pytorch_lightning as pl -import torch.distributed as dist +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers import wandb -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_only -from safetensors.torch import save_file - -from videotuna.third_party.flux.configuration.configure import model_labels -from videotuna.third_party.flux.publishing.huggingface import HubManager -from videotuna.third_party.flux.training.default_settings.safety_check import ( - safety_check, -) -from videotuna.utils.callbacks import LoraModelCheckpoint +from accelerate import Accelerator # Quiet down, you. 
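diffusion_model.py above only swaps in the FlashAttention-3 processors when the primary device reports compute capability 9.x (Hopper) and `flash_attn_interface` imports cleanly. A minimal capability probe in the same spirit (the helper name is illustrative):

import torch

def can_use_flash_attention_3(device_index: int = 0) -> bool:
    # FA3 kernels target SM 9.x (H100 class); everything else should fall
    # back to FlashAttention-2 or PyTorch SDPA.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_index)
    return major >= 9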
-os.environ["ACCELERATE_LOG_LEVEL"] = "WARNING" from accelerate.logging import get_logger +from accelerate.utils import set_seed from diffusers.models.embeddings import get_2d_rotary_pos_embed +from pytorch_lightning.utilities import rank_zero_only +from torch.distributions import Beta from videotuna.third_party.flux import log_format # noqa from videotuna.third_party.flux.caching.memory import reclaim_memory +from videotuna.third_party.flux.configuration.configure import ( + model_classes, + model_labels, +) from videotuna.third_party.flux.configuration.loader import load_config from videotuna.third_party.flux.data_backend.factory import ( BatchFetcher, @@ -38,7 +39,7 @@ random_dataloader_iterator, ) from videotuna.third_party.flux.models.smoldit import get_resize_crop_region_for_grid -from videotuna.third_party.flux.training import steps_remaining_in_epoch +from videotuna.third_party.flux.publishing.huggingface import HubManager from videotuna.third_party.flux.training.adapter import ( determine_adapter_target_modules, load_lora_weights, @@ -52,6 +53,9 @@ deepspeed_zero_init_disabled_context_manager, prepare_model_for_deepspeed, ) +from videotuna.third_party.flux.training.default_settings.safety_check import ( + safety_check, +) from videotuna.third_party.flux.training.diffusion_model import load_diffusion_model from videotuna.third_party.flux.training.min_snr_gamma import compute_snr from videotuna.third_party.flux.training.multi_process import _get_rank as get_rank @@ -77,39 +81,11 @@ prepare_validation_prompt_list, ) from videotuna.third_party.flux.training.wrappers import unwrap_model - -logger = get_logger( - "SimpleTuner", log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") -) - -filelock_logger = get_logger("filelock") -connection_logger = get_logger("urllib3.connectionpool") -training_logger = get_logger("training-loop") - -# More important logs. -target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") -logger.setLevel(target_level) -training_logger_level = os.environ.get("SIMPLETUNER_TRAINING_LOOP_LOG_LEVEL", "INFO") -training_logger.setLevel(training_logger_level) - -# Less important logs. -filelock_logger.setLevel("WARNING") -connection_logger.setLevel("WARNING") -import accelerate -import diffusers -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.utils import set_seed -from torch.distributions import Beta - -from videotuna.third_party.flux.configuration.configure import model_classes +from videotuna.utils.callbacks import LoraModelCheckpoint try: from lycoris import LycorisNetwork -except: +except Exception: print("[ERROR] Lycoris not available. 
Please install ") from diffusers import ( AutoencoderKL, @@ -120,20 +96,14 @@ EulerDiscreteScheduler, FluxTransformer2DModel, PixArtTransformer2DModel, - StableDiffusion3Pipeline, UNet2DConditionModel, UniPCMultistepScheduler, ) -from diffusers.utils import ( - check_min_version, - convert_state_dict_to_diffusers, - is_wandb_available, -) +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers from diffusers.utils.import_utils import is_xformers_available from peft import LoraConfig from peft.utils import get_peft_model_state_dict from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig from transformers.utils import ContextManagers from videotuna.third_party.flux.models.flux import ( @@ -143,15 +113,33 @@ prepare_latent_image_ids, unpack_latents, ) -from videotuna.third_party.flux.models.sdxl.pipeline import StableDiffusionXLPipeline from videotuna.third_party.flux.training.ema import EMAModel +os.environ["ACCELERATE_LOG_LEVEL"] = "WARNING" +logger = get_logger( + "SimpleTuner", log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") +) + +filelock_logger = get_logger("filelock") +connection_logger = get_logger("urllib3.connectionpool") +training_logger = get_logger("training-loop") + +# More important logs. +target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") +logger.setLevel(target_level) +training_logger_level = os.environ.get("SIMPLETUNER_TRAINING_LOOP_LOG_LEVEL", "INFO") +training_logger.setLevel(training_logger_level) + +# Less important logs. +filelock_logger.setLevel("WARNING") +connection_logger.setLevel("WARNING") + is_optimi_available = False try: from optimi import prepare_for_gradient_release is_optimi_available = True -except: +except Exception: pass # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -483,7 +471,7 @@ def init_vae(self, move_to_accelerator: bool = True): } try: self.vae = AutoencoderKL.from_pretrained(**self.config.vae_kwargs) - except: + except Exception: logger.warning( "Couldn't load VAE with default path. Trying without a subfolder.." ) @@ -763,8 +751,9 @@ def init_unload_text_encoder(self): memory_after_unload = self.stats_memory_used() memory_saved = memory_after_unload - memory_before_unload logger.info( - f"After nuking text encoders from orbit, we freed {abs(round(memory_saved, 2))} GB of VRAM." - " The real memories were the friends we trained a model on along the way." + f"After nuking text encoders from orbit, we freed " + f"{abs(round(memory_saved, 2))} GB of VRAM. The real memories were the " + f"friends we trained a model on along the way." 
) def init_precision(self): @@ -788,7 +777,8 @@ def init_precision(self): self.config.enable_adamw_bf16 = True if self.unet is not None: logger.info( - f"Moving U-net to dtype={self.config.base_weight_dtype}, device={quantization_device}" + f"Moving U-net to dtype={self.config.base_weight_dtype}," + f" device={quantization_device}" ) self.unet.to(quantization_device, dtype=self.config.base_weight_dtype) elif self.transformer is not None: @@ -1168,7 +1158,7 @@ def init_lr_scheduler(self): use_deepspeed_scheduler=False, ) else: - logger.info(f"Using dummy learning rate scheduler") + logger.info("Using dummy learning rate scheduler") if torch.backends.mps.is_available(): lr_scheduler = None else: @@ -1487,7 +1477,7 @@ def init_resume_checkpoint(self, lr_scheduler): structured_data={"message": f"Resuming model: {path}"}, message_type="init_resume_checkpoint", ) - training_state_filename = f"training_state.json" + training_state_filename = "training_state.json" if get_rank() > 0: training_state_filename = f"training_state-{get_rank()}.json" for _, backend in StateTracker.get_data_backends().items(): @@ -1793,16 +1783,16 @@ def abort(self): # we should set should_abort = True on each data backend's vae cache, metadata, and text backend for _, backend in StateTracker.get_data_backends().items(): if "vaecache" in backend: - logger.debug(f"Aborting VAE cache") + logger.debug("Aborting VAE cache") backend["vaecache"].should_abort = True if "metadata_backend" in backend: - logger.debug(f"Aborting metadata backend") + logger.debug("Aborting metadata backend") backend["metadata_backend"].should_abort = True if "text_backend" in backend: - logger.debug(f"Aborting text backend") + logger.debug("Aborting text backend") backend["text_backend"].should_abort = True if "sampler" in backend: - logger.debug(f"Aborting sampler") + logger.debug("Aborting sampler") backend["sampler"].should_abort = True self.should_abort = True @@ -2823,7 +2813,6 @@ def on_train_end(self): text_encoder_2_lora_layers = None if self.config.model_family == "flux": - from diffusers.pipelines import FluxPipeline print("saving lora...") self.save_lora() diff --git a/videotuna/third_party/flux/training/model_data.py b/videotuna/third_party/flux/training/model_data.py index a0cb9b12..7d7889b3 100644 --- a/videotuna/third_party/flux/training/model_data.py +++ b/videotuna/third_party/flux/training/model_data.py @@ -1,11 +1,10 @@ -import json import os from pathlib import Path import pytorch_lightning as pl from accelerate.logging import get_logger from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DataLoader from videotuna.third_party.flux.data_backend.factory import configure_multi_databackend from videotuna.third_party.flux.training.state_tracker import StateTracker @@ -14,6 +13,7 @@ "SimpleTuner", log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") ) + def create_txt_labels_from_dir(data_dir, caption): """ Create multiple txt files, each txt file is the content of the caption string. 
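The model_data.py hunks here touch `create_txt_labels_from_dir`, which writes a sidecar caption file per training image so captions can be picked up with a textfile strategy. A self-contained sketch of that helper (the image extension filter is an assumption; only the stem-based .txt naming comes from the hunk below):

import os
from pathlib import Path

def create_txt_labels_from_dir(data_dir, caption, extensions=(".png", ".jpg", ".jpeg", ".webp")):
    # One .txt per image, named after the image stem, containing the caption string.
    for image in os.listdir(data_dir):
        if not image.lower().endswith(extensions):
            continue
        with open(os.path.join(data_dir, Path(image).stem) + ".txt", "w") as f:
            f.write(caption)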
@@ -22,6 +22,7 @@ def create_txt_labels_from_dir(data_dir, caption): with open(os.path.join(data_dir, Path(image).stem) + ".txt", "w") as f: f.write(caption) + class ModelData(pl.LightningDataModule): def __init__( self, diff --git a/videotuna/third_party/flux/training/model_freeze.py b/videotuna/third_party/flux/training/model_freeze.py index 7e565f15..4b882a65 100644 --- a/videotuna/third_party/flux/training/model_freeze.py +++ b/videotuna/third_party/flux/training/model_freeze.py @@ -1,6 +1,5 @@ import logging import os -import re from torch import nn @@ -25,7 +24,7 @@ def freeze_transformer_blocks( f"Invalid freeze_direction value {freeze_direction}. Choose from 'up', 'down'." ) if first_unfrozen_dit_layer < 0 or first_unfrozen_mmdit_layer < 0: - raise ValueError(f"Invalid first_unfrozen layer value. Must be greater than 0.") + raise ValueError("Invalid first_unfrozen layer value. Must be greater than 0.") for name, param in model.named_parameters(): # Example names: # single_transformer_blocks.31.ff.c_proj.weight @@ -33,7 +32,7 @@ def freeze_transformer_blocks( try: layer_group = name.split(".")[0] layer_number = int(name.split(".")[1]) - except Exception as e: + except Exception: # non-numeric layer. continue try: diff --git a/videotuna/third_party/flux/training/multi_process.py b/videotuna/third_party/flux/training/multi_process.py index 3448a675..cf472b64 100644 --- a/videotuna/third_party/flux/training/multi_process.py +++ b/videotuna/third_party/flux/training/multi_process.py @@ -11,7 +11,7 @@ def _get_rank(): def rank_info(): try: return f"(Rank: {_get_rank()}) " - except: + except Exception: return "" diff --git a/videotuna/third_party/flux/training/optimizer_param.py b/videotuna/third_party/flux/training/optimizer_param.py index 32f322e3..6606c5bc 100644 --- a/videotuna/third_party/flux/training/optimizer_param.py +++ b/videotuna/third_party/flux/training/optimizer_param.py @@ -4,12 +4,6 @@ import torch from accelerate.logging import get_logger -logger = get_logger(__name__, log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) - -target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") -logger.setLevel(target_level) - -is_optimi_available = False from videotuna.third_party.flux.training.optimizers.adamw_bfloat16 import AdamWBF16 from videotuna.third_party.flux.training.optimizers.adamw_schedulefree import ( AdamWScheduleFreeKahan, @@ -17,8 +11,8 @@ from videotuna.third_party.flux.training.optimizers.soap import SOAP try: - from optimum.quanto import QTensor -except: + pass +except Exception: pass try: @@ -38,11 +32,13 @@ print("You need torchao installed for its low-precision optimizers.") raise e +logger = get_logger(__name__, log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) + try: import optimi is_optimi_available = True -except: +except Exception: logger.error( "Could not load optimi library. Please install `torch-optimi` for better memory efficiency." ) @@ -52,12 +48,18 @@ import bitsandbytes is_bitsandbytes_available = True -except: +except Exception: if torch.cuda.is_available(): logger.warning( "Could not load bitsandbytes library. BnB-specific optimisers and other functionality will be unavailable." 
) + +target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") +logger.setLevel(target_level) + +is_optimi_available = globals().get("is_optimi_available", False) + optimizer_choices = { "adamw_bf16": { "precision": "bf16", @@ -591,7 +593,7 @@ def determine_optimizer_class_with_config( optimizer_details = {} elif is_quantized and not enable_adamw_bf16: logger.error( - f"When --base_model_default_dtype=fp32, AdamWBF16 may not be used. Switching to AdamW." + "When --base_model_default_dtype=fp32, AdamWBF16 may not be used. Switching to AdamW." ) optimizer_class, optimizer_details = optimizer_parameters("optimi-adamw", args) else: diff --git a/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/__init__.py b/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/__init__.py index 796a2bbc..4078a5d9 100644 --- a/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/__init__.py +++ b/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/__init__.py @@ -3,20 +3,18 @@ they have identical interface, but sutiable for different scenarios. """ -__version__ = "0.2.0" +import torch +from torch.optim.optimizer import Optimizer -__all__ = ["AdamW_BF16"] +from .stochastic import add_stochastic_, addcdiv_stochastic_ + +__version__ = "0.2.0" """ This implementation uses torch.compile to speed up, should be suitable for different backends. """ -import torch -from torch.optim.optimizer import Optimizer - -from .stochastic import add_stochastic_, addcdiv_stochastic_ - class AdamWBF16(Optimizer): decay_threshold = 5e-3 diff --git a/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/stochastic/__init__.py b/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/stochastic/__init__.py index 7829aabf..81cf17b3 100644 --- a/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/stochastic/__init__.py +++ b/videotuna/third_party/flux/training/optimizers/adamw_bfloat16/stochastic/__init__.py @@ -1,16 +1,16 @@ import torch -from torch import FloatTensor, Tensor +from torch import Tensor -def swap_first_and_last_dims(tensor: torch.Tensor) -> torch.Tensor: +def swap_first_and_last_dims(tensor: Tensor) -> Tensor: """ Swap the first dimension with the last dimension of a tensor. Args: - tensor (torch.Tensor): The input tensor of any shape. + tensor (Tensor): The input tensor of any shape. Returns: - torch.Tensor: A tensor with the first dimension swapped with the last. + Tensor: A tensor with the first dimension swapped with the last. """ # Get the total number of dimensions num_dims = len(tensor.shape) @@ -22,16 +22,16 @@ def swap_first_and_last_dims(tensor: torch.Tensor) -> torch.Tensor: return tensor.permute(*new_order) -def swap_back_first_and_last_dims(tensor: torch.Tensor) -> torch.Tensor: +def swap_back_first_and_last_dims(tensor: Tensor) -> Tensor: """ Swap back the first dimension with the last dimension of a tensor to its original shape after a swap. Args: - tensor (torch.Tensor): The tensor that had its first and last dimensions swapped. + tensor (Tensor): The tensor that had its first and last dimensions swapped. Returns: - torch.Tensor: A tensor with its original shape restored. + Tensor: A tensor with its original shape restored.
""" # Get the total number of dimensions num_dims = len(tensor.shape) diff --git a/videotuna/third_party/flux/training/optimizers/adamw_schedulefree/__init__.py b/videotuna/third_party/flux/training/optimizers/adamw_schedulefree/__init__.py index 3e64082f..6057d9b2 100644 --- a/videotuna/third_party/flux/training/optimizers/adamw_schedulefree/__init__.py +++ b/videotuna/third_party/flux/training/optimizers/adamw_schedulefree/__init__.py @@ -1,4 +1,3 @@ -import math from typing import Iterable import torch diff --git a/videotuna/third_party/flux/training/optimizers/soap/__init__.py b/videotuna/third_party/flux/training/optimizers/soap/__init__.py index 0d5c3927..6736040f 100644 --- a/videotuna/third_party/flux/training/optimizers/soap/__init__.py +++ b/videotuna/third_party/flux/training/optimizers/soap/__init__.py @@ -1,7 +1,6 @@ from itertools import chain import torch -import torch.nn as nn import torch.optim as optim # Parts of the code are modifications of Pytorch's AdamW optimizer @@ -407,7 +406,7 @@ def get_orthogonal_matrix(self, mat): _, Q = torch.linalg.eigh( m + 1e-30 * torch.eye(m.shape[0], device=m.device) ) - except: + except Exception: _, Q = torch.linalg.eigh( m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device) ) diff --git a/videotuna/third_party/flux/training/quantisation/__init__.py b/videotuna/third_party/flux/training/quantisation/__init__.py index 07356201..95d5d1cd 100644 --- a/videotuna/third_party/flux/training/quantisation/__init__.py +++ b/videotuna/third_party/flux/training/quantisation/__init__.py @@ -48,9 +48,7 @@ def _quanto_model( quantize_activations: bool = False, ): try: - from optimum.quanto import QTensor, freeze, quantize - - from videotuna.third_party.flux.training.quantisation import quanto_workarounds + from optimum.quanto import freeze, quantize except ImportError as e: raise ImportError( f"To use Quanto, please install the optimum library: `pip install optimum-quanto`: {e}" @@ -138,14 +136,12 @@ def _torchao_model( return model try: - import torchao from torchao.float8 import Float8LinearConfig, convert_to_float8_training from torchao.prototype.quantized_training import ( int8_weight_only_quantized_training, ) from torchao.quantization import quantize_ - from videotuna.third_party.flux.training.quantisation import torchao_workarounds except ImportError as e: raise ImportError( f"To use torchao, please install the torchao library: `pip install torchao`: {e}" diff --git a/videotuna/third_party/flux/training/quantisation/quanto_workarounds.py b/videotuna/third_party/flux/training/quantisation/quanto_workarounds.py index c6b286d4..86a69202 100644 --- a/videotuna/third_party/flux/training/quantisation/quanto_workarounds.py +++ b/videotuna/third_party/flux/training/quantisation/quanto_workarounds.py @@ -4,7 +4,6 @@ if torch.cuda.is_available(): # the marlin fp8 kernel needs some help with dtype casting for some reason # see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201 - from optimum.quanto.library.extensions.cuda import ext as quanto_ext # Save the original operator original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin diff --git a/videotuna/third_party/flux/training/state_tracker.py b/videotuna/third_party/flux/training/state_tracker.py index a72910fd..fd3b08d5 100644 --- a/videotuna/third_party/flux/training/state_tracker.py +++ b/videotuna/third_party/flux/training/state_tracker.py @@ -80,7 +80,7 @@ def delete_cache_files( if cache_path.exists(): try: cache_path.unlink() - except: + except 
Exception: pass @classmethod diff --git a/videotuna/third_party/flux/training/text_encoding.py b/videotuna/third_party/flux/training/text_encoding.py index 31b7c943..e65132ed 100644 --- a/videotuna/third_party/flux/training/text_encoding.py +++ b/videotuna/third_party/flux/training/text_encoding.py @@ -164,7 +164,7 @@ def get_tokenizers(args): revision=args.revision, use_fast=True, ) - except: + except Exception: raise ValueError( "Could not load tertiary tokenizer (T5-XXL v1.1). Cannot continue." ) diff --git a/videotuna/third_party/flux/training/trainer.py b/videotuna/third_party/flux/training/trainer.py index de0128b4..ac9dbc7c 100644 --- a/videotuna/third_party/flux/training/trainer.py +++ b/videotuna/third_party/flux/training/trainer.py @@ -1,5 +1,4 @@ import copy -import glob import hashlib import json import logging @@ -21,8 +20,17 @@ # Quiet down, you. os.environ["ACCELERATE_LOG_LEVEL"] = "WARNING" +import accelerate +import diffusers +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.utils import set_seed +from configure import model_classes from diffusers.models.embeddings import get_2d_rotary_pos_embed +from torch.distributions import Beta from videotuna.third_party.flux import log_format # noqa from videotuna.third_party.flux.caching.memory import reclaim_memory @@ -33,7 +41,6 @@ random_dataloader_iterator, ) from videotuna.third_party.flux.models.smoldit import get_resize_crop_region_for_grid -from videotuna.third_party.flux.training import steps_remaining_in_epoch from videotuna.third_party.flux.training.adapter import ( determine_adapter_target_modules, load_lora_weights, @@ -73,37 +80,9 @@ ) from videotuna.third_party.flux.training.wrappers import unwrap_model -logger = get_logger( - "SimpleTuner", log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") -) - -filelock_logger = get_logger("filelock") -connection_logger = get_logger("urllib3.connectionpool") -training_logger = get_logger("training-loop") - -# More important logs. -target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") -logger.setLevel(target_level) -training_logger_level = os.environ.get("SIMPLETUNER_TRAINING_LOOP_LOG_LEVEL", "INFO") -training_logger.setLevel(training_logger_level) - -# Less important logs. -filelock_logger.setLevel("WARNING") -connection_logger.setLevel("WARNING") -import accelerate -import diffusers -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.utils import set_seed -from configure import model_classes -from torch.distributions import Beta - try: from lycoris import LycorisNetwork -except: +except Exception: print("[ERROR] Lycoris not available. 
Please install ") from diffusers import ( AutoencoderKL, @@ -118,16 +97,11 @@ UNet2DConditionModel, UniPCMultistepScheduler, ) -from diffusers.utils import ( - check_min_version, - convert_state_dict_to_diffusers, - is_wandb_available, -) +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers from diffusers.utils.import_utils import is_xformers_available from peft import LoraConfig from peft.utils import get_peft_model_state_dict from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig from transformers.utils import ContextManagers from videotuna.third_party.flux.models.flux import ( @@ -140,12 +114,30 @@ from videotuna.third_party.flux.models.sdxl.pipeline import StableDiffusionXLPipeline from videotuna.third_party.flux.training.ema import EMAModel +logger = get_logger( + "SimpleTuner", log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") +) + +filelock_logger = get_logger("filelock") +connection_logger = get_logger("urllib3.connectionpool") +training_logger = get_logger("training-loop") + +# More important logs. +target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") +logger.setLevel(target_level) +training_logger_level = os.environ.get("SIMPLETUNER_TRAINING_LOOP_LOG_LEVEL", "INFO") +training_logger.setLevel(training_logger_level) + +# Less important logs. +filelock_logger.setLevel("WARNING") +connection_logger.setLevel("WARNING") + is_optimi_available = False try: from optimi import prepare_for_gradient_release is_optimi_available = True -except: +except Exception: pass # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -476,7 +468,7 @@ def init_vae(self, move_to_accelerator: bool = True): } try: self.vae = AutoencoderKL.from_pretrained(**self.config.vae_kwargs) - except: + except Exception: logger.warning( "Couldn't load VAE with default path. Trying without a subfolder.." 
) @@ -1161,7 +1153,7 @@ def init_lr_scheduler(self): use_deepspeed_scheduler=False, ) else: - logger.info(f"Using dummy learning rate scheduler") + logger.info("Using dummy learning rate scheduler") if torch.backends.mps.is_available(): lr_scheduler = None else: @@ -1480,7 +1472,7 @@ def init_resume_checkpoint(self, lr_scheduler): structured_data={"message": f"Resuming model: {path}"}, message_type="init_resume_checkpoint", ) - training_state_filename = f"training_state.json" + training_state_filename = "training_state.json" if get_rank() > 0: training_state_filename = f"training_state-{get_rank()}.json" for _, backend in StateTracker.get_data_backends().items(): @@ -1786,16 +1778,16 @@ def abort(self): # we should set should_abort = True on each data backend's vae cache, metadata, and text backend for _, backend in StateTracker.get_data_backends().items(): if "vaecache" in backend: - logger.debug(f"Aborting VAE cache") + logger.debug("Aborting VAE cache") backend["vaecache"].should_abort = True if "metadata_backend" in backend: - logger.debug(f"Aborting metadata backend") + logger.debug("Aborting metadata backend") backend["metadata_backend"].should_abort = True if "text_backend" in backend: - logger.debug(f"Aborting text backend") + logger.debug("Aborting text backend") backend["text_backend"].should_abort = True if "sampler" in backend: - logger.debug(f"Aborting sampler") + logger.debug("Aborting sampler") backend["sampler"].should_abort = True self.should_abort = True diff --git a/videotuna/third_party/flux/training/validation.py b/videotuna/third_party/flux/training/validation.py index 03cc84b6..b6e991a8 100644 --- a/videotuna/third_party/flux/training/validation.py +++ b/videotuna/third_party/flux/training/validation.py @@ -1,19 +1,21 @@ import logging import os +import time import numpy as np import torch import wandb +from diffusers import AutoencoderKL, DDIMScheduler # from toolsegacy.pipeline import StableDiffusionPipeline from diffusers.schedulers import ( - DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, ) +from diffusers.utils import is_wandb_available from diffusers.utils.torch_utils import is_compiled_module from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm @@ -24,6 +26,7 @@ StableDiffusionXLPipeline, ) from videotuna.third_party.flux.multiaspect.image import MultiaspectImage +from videotuna.third_party.flux.prompts import PromptHandler from videotuna.third_party.flux.training.state_tracker import StateTracker from videotuna.third_party.flux.training.wrappers import unwrap_model @@ -50,14 +53,6 @@ "ddpm": DDPMScheduler, } -import logging -import os -import time - -from diffusers import AutoencoderKL, DDIMScheduler -from diffusers.utils import is_wandb_available - -from videotuna.third_party.flux.prompts import PromptHandler if is_wandb_available(): import wandb @@ -141,8 +136,6 @@ def retrieve_validation_images(): def prepare_validation_prompt_list(args, embed_cache): - validation_negative_prompt_embeds = None - validation_negative_pooled_embeds = None validation_prompts = ( [""] if not StateTracker.get_args().validation_disable_unconditional else [] ) diff --git a/videotuna/third_party/flux/webhooks/handler.py b/videotuna/third_party/flux/webhooks/handler.py index 85024568..6b64ddb7 100644 --- a/videotuna/third_party/flux/webhooks/handler.py +++ b/videotuna/third_party/flux/webhooks/handler.py @@ -1,4 +1,3 @@ -import json import logging import os import time diff 
--git a/videotuna/utils/callbacks.py b/videotuna/utils/callbacks.py index 2418f8bd..ac403fc2 100755 --- a/videotuna/utils/callbacks.py +++ b/videotuna/utils/callbacks.py @@ -1,15 +1,7 @@ -import datetime import logging import os import time -import numpy as np -from einops import rearrange -from omegaconf import OmegaConf -from PIL import Image - -mainlogger = logging.getLogger("mainlogger") - import pytorch_lightning as pl import torch import torchvision @@ -18,6 +10,8 @@ from .save_video import log_local, prepare_to_log +mainlogger = logging.getLogger("mainlogger") + class LoraModelCheckpoint(pl.callbacks.ModelCheckpoint): def __init__(self, *args, **kwargs): diff --git a/videotuna/utils/common_utils.py b/videotuna/utils/common_utils.py index 0ef55e26..35178514 100644 --- a/videotuna/utils/common_utils.py +++ b/videotuna/utils/common_utils.py @@ -27,7 +27,7 @@ def check_istarget(name, para_list): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": diff --git a/videotuna/utils/inference_utils.py b/videotuna/utils/inference_utils.py index de156ca1..fde0f687 100644 --- a/videotuna/utils/inference_utils.py +++ b/videotuna/utils/inference_utils.py @@ -1,6 +1,4 @@ -import glob import os -import sys from collections import OrderedDict import cv2 @@ -53,14 +51,14 @@ def load_checkpoint(model, ckpt, full_strict): for key in state_dict["module"].keys(): new_pl_sd[key[16:]] = state_dict["module"][key] model.load_state_dict(new_pl_sd, strict=full_strict) - except: + except Exception: if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] try: model.model.diffusion_model.load_state_dict( state_dict, strict=full_strict ) - except: + except Exception: model.load_state_dict(state_dict, strict=False) return model @@ -141,9 +139,8 @@ def load_inputs_v2v(input_dir, video_size=None, video_frames=None): print(prompt_files) raise ValueError(f"Error: found NO prompt file in {input_dir}") prompt_list = load_prompts_from_txt(prompt_file) - n_samples = len(prompt_list) - ## load videos + # load videos video_filepaths = get_target_filelist(input_dir, ext="mp4") video_filenames = [ os.path.split(video_filepath)[-1] for video_filepath in video_filepaths diff --git a/videotuna/utils/lightning_utils.py b/videotuna/utils/lightning_utils.py index 87821255..21b27c97 100644 --- a/videotuna/utils/lightning_utils.py +++ b/videotuna/utils/lightning_utils.py @@ -1,6 +1,6 @@ import inspect from argparse import ArgumentParser -from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, cast +from typing import Any, Callable, Dict, List, Tuple, Type, Union import pytorch_lightning as pl @@ -149,7 +149,7 @@ def add_trainer_args_to_parser(cls, parent_parser, use_argument_group=True): raise RuntimeError("Please only pass an `ArgumentParser` instance.") if use_argument_group: group_name = _get_abbrev_qualified_cls_name(cls) - parser: _ADD_ARGPARSE_RETURN = parent_parser.add_argument_group(group_name) + parser = parent_parser.add_argument_group(group_name) else: parser = ArgumentParser(parents=[parent_parser], add_help=False) @@ -212,7 +212,7 @@ def add_trainer_args_to_parser(cls, parent_parser, use_argument_group=True): required=(arg_default == inspect._empty), **arg_kwargs, ) - except: + except Exception: # TODO: check the argument appending to the parser pass diff --git a/videotuna/utils/load_weights.py b/videotuna/utils/load_weights.py index 5a7e8763..f78ea5f9 
100755 --- a/videotuna/utils/load_weights.py +++ b/videotuna/utils/load_weights.py @@ -1,20 +1,17 @@ import copy import logging - -from omegaconf import OmegaConf - -mainlogger = logging.getLogger("mainlogger") - from collections import OrderedDict import torch +from omegaconf import OmegaConf from safetensors import safe_open -from torch import nn from videotuna.utils.common_utils import instantiate_from_config # from lvdm.personalization.lora import net_load_lora +mainlogger = logging.getLogger("mainlogger") + def expand_conv_kernel(pretrained_dict): """expand 2d conv parameters from 4D -> 5D""" @@ -29,9 +26,9 @@ def load_from_pretrainedSD_checkpoint( model, pretained_ckpt, expand_to_3d=True, adapt_keyname=False ): mainlogger.info( - f"------------------- Load pretrained SD weights -------------------" + "------------------- Load pretrained SD weights -------------------" ) - sd_state_dict = torch.load(pretained_ckpt, map_location=f"cpu") + sd_state_dict = torch.load(pretained_ckpt, map_location="cpu") if "state_dict" in list(sd_state_dict.keys()): sd_state_dict = sd_state_dict["state_dict"] model_state_dict = model.state_dict() @@ -83,8 +80,8 @@ def load_from_pretrainedSD_checkpoint( # load the new state dict try: model.load_state_dict(model_state_dict) - except: - state_dict = torch.load(model_state_dict, map_location=f"cpu") + except Exception: + state_dict = torch.load(model_state_dict, map_location="cpu") if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] model_state_dict = model.state_dict() @@ -96,9 +93,7 @@ def load_from_pretrainedSD_checkpoint( model_state_dict.update(state_dict) model.load_state_dict(model_state_dict) - mainlogger.info( - f"---------------------------- Finish! ----------------------------" - ) + mainlogger.info("---------------------------- Finish! ----------------------------") return model, empty_paras @@ -159,7 +154,7 @@ def load_partial_weights( model_dict_ori = copy.deepcopy(model_dict) mainlogger.info( - f"-------------- Load pretrained LDM weights --------------------------" + "-------------- Load pretrained LDM weights --------------------------" ) mainlogger.info(f"Num of parameters of target model: {len(model_dict.keys())}") mainlogger.info(f"Num of parameters of source model: {len(pretrained_dict.keys())}") @@ -209,7 +204,7 @@ def load_partial_weights( # load the new state dict try: model2.load_state_dict(model_dict) - except: + except Exception: # if parameter size mismatch, skip them skipped = [] for n, p in model_dict.items(): @@ -229,7 +224,7 @@ def load_partial_weights( mainlogger.info(f"Empty parameters: {len(empty_paras)} ") # import pdb;pdb.set_trace() - mainlogger.info(f"-------------- Finish! --------------------------") + mainlogger.info("-------------- Finish! 
--------------------------") return model2, empty_paras @@ -318,7 +313,7 @@ def convert_lora( LORA_PREFIX_TEXT_ENCODER + "_" )[-1].split("_") curr_layer = model - # if type(model.cond_stage_model) == "FrozenOpenCLIPEmbedder": + # if isinstance(model.cond_stage_model, FrozenOpenCLIPEmbedder): # curr_layer = model.cond_stage_model.model elif LORA_PREFIX_UNET in key: # else: diff --git a/videotuna/utils/save_video.py b/videotuna/utils/save_video.py index a93a7706..77505d3f 100755 --- a/videotuna/utils/save_video.py +++ b/videotuna/utils/save_video.py @@ -171,12 +171,11 @@ def save_img_grid(grid, path, rescale): f.write(f"idx={i}, txt={txt}\n") f.close() elif isinstance(value, torch.Tensor) and value.dim() == 5: - ## save video grids + # save video grids video = value # b,c,t,h,w - ## only save grayscale or rgb mode + # only save grayscale or rgb mode if video.shape[1] != 1 and video.shape[1] != 3: continue - n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w frame_grids = [ torchvision.utils.make_grid(framesheet, nrow=int(1)) @@ -201,12 +200,11 @@ def save_img_grid(grid, path, rescale): path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename)) # save_img_grid(grid, path, rescale) elif isinstance(value, torch.Tensor) and value.dim() == 4: - ## save image grids + # save image grids img = value - ## only save grayscale or rgb mode + # only save grayscale or rgb mode if img.shape[1] != 1 and img.shape[1] != 3: continue - n = img.shape[0] grid = torchvision.utils.make_grid(img, nrow=1) path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename)) save_img_grid(grid, path, rescale) @@ -240,7 +238,7 @@ def prepare_to_log(batch_logs, max_images=100000, clamp=True): return batch_logs -# ---------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------- def fill_with_black_squares(video, desired_len: int) -> Tensor: @@ -258,7 +256,7 @@ def fill_with_black_squares(video, desired_len: int) -> Tensor: ) -# ---------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------- def load_num_videos(data_path, num_videos): # first argument can be either data_path of np array if isinstance(data_path, str): @@ -276,7 +274,7 @@ def load_num_videos(data_path, num_videos): def npz_to_video_grid( data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True ): - # videos = torch.tensor(np.load(data_path)['arr_0']).permute(0,1,4,2,3).div_(255).mul_(2) - 1.0 # NTHWC->NTCHW, np int -> torch tensor 0-1 + # videos = torch.tensor(np.load(data_path)['arr_0']).permute(0,1,4,2,3).div_(255).mul_(2) - 1.0 # NTHWC->NTCHW, np int -> torch tensor 0-1 # noqa: E501 if isinstance(data_path, str): videos = load_num_videos(data_path, num_videos) elif isinstance(data_path, np.ndarray): diff --git a/videotuna/utils/train_utils.py b/videotuna/utils/train_utils.py index 58542765..f0639a56 100755 --- a/videotuna/utils/train_utils.py +++ b/videotuna/utils/train_utils.py @@ -1,23 +1,14 @@ -import argparse -import glob import logging -import multiprocessing as mproc import os -import sys from collections import OrderedDict -from omegaconf import OmegaConf -from packaging import version - -mainlogger = logging.getLogger("mainlogger") - -from collections import OrderedDict - -import pytorch_lightning as pl import torch +from omegaconf import OmegaConf from videotuna.utils.load_weights import 
load_from_pretrainedSD_checkpoint +mainlogger = logging.getLogger("mainlogger") + def init_workspace(name, logdir, model_config, lightning_config, rank=0): workdir = os.path.join(logdir, name) @@ -25,7 +16,8 @@ def init_workspace(name, logdir, model_config, lightning_config, rank=0): cfgdir = os.path.join(workdir, "configs") loginfo = os.path.join(workdir, "loginfo") - # Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower) + # Create logdirs and save configs + # (all ranks will do to avoid missing directory error if rank:0 is slower) os.makedirs(workdir, exist_ok=True) os.makedirs(ckptdir, exist_ok=True) os.makedirs(cfgdir, exist_ok=True) @@ -90,7 +82,8 @@ def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger): if "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks: mainlogger.info( - "Caution: Saving checkpoints every n train steps without deleting. This might require some free space." + "Caution: Saving checkpoints every n train steps without deleting." + " This might require some free space." ) default_metrics_over_trainsteps_ckpt_dict = { "metrics_over_trainsteps_checkpoint": { @@ -159,7 +152,7 @@ def get_trainer_strategy(lightning_config): def load_checkpoints(model, model_cfg): - ## special load setting for adapter training + # special load setting for adapter training if check_config_attribute(model_cfg, "adapter_only"): pretrained_ckpt = model_cfg.pretrained_checkpoint assert os.path.exists(pretrained_ckpt), ( @@ -170,7 +163,7 @@ def load_checkpoints(model, model_cfg): ) print(f"Loading model from {pretrained_ckpt}") ## only load weight for the backbone model (e.g. latent diffusion model) - state_dict = torch.load(pretrained_ckpt, map_location=f"cpu") + state_dict = torch.load(pretrained_ckpt, map_location="cpu") if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] else: @@ -208,7 +201,7 @@ def load_checkpoints(model, model_cfg): for key in pl_sd["module"].keys(): new_pl_sd[key[16:]] = pl_sd["module"][key] model.load_state_dict(new_pl_sd) - except: + except Exception: if "state_dict" in pl_sd.keys(): model.load_state_dict(pl_sd["state_dict"], strict=False) else: @@ -217,7 +210,7 @@ def load_checkpoints(model, model_cfg): """ try: model = model.load_from_checkpoint(pretrained_ckpt, **model_cfg.params) - except: + except Exception: mainlogger.info("[Warning] checkpoint NOT complete matched. To adapt by skipping ...") state_dict = torch.load(pretrained_ckpt, map_location=f"cpu") if "state_dict" in list(state_dict.keys()): @@ -276,7 +269,7 @@ def get_autoresume_path(logdir): gs = tmp["global_step"] mainlogger.info(f"[INFO] Resume from epoch {e}, global step {gs}!") del tmp - except: + except Exception: try: mainlogger.info("Load last.ckpt failed!") ckpts = sorted( @@ -301,7 +294,8 @@ def get_autoresume_path(logdir): else: resume_checkpt_path = None mainlogger.info( - f"[INFO] no checkpoint found in current workspace: {os.path.join(logdir, 'checkpoints')}" + f"[INFO] no checkpoint found in current workspace: " + f"{os.path.join(logdir, 'checkpoints')}" ) return resume_checkpt_path
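
Note on the optional-dependency guards touched above (optimi, bitsandbytes, lycoris): the sketch below is a minimal illustration of the availability-flag pattern these modules rely on, using the stdlib logger in place of the accelerate logger, and it is not the exact layout of optimizer_param.py. Assigning the flag in both branches keeps it correct no matter how later module-level statements are reordered, and narrowing the handler to `except Exception` (as these changes do, to satisfy the bare-except rule E722) still avoids swallowing `KeyboardInterrupt` and `SystemExit`.

import logging

logger = logging.getLogger("SimpleTuner")

try:
    # Optional dependency: only the availability flag is needed in this sketch.
    import optimi  # noqa: F401

    is_optimi_available = True
except Exception:
    logger.error(
        "Could not load optimi library. Please install `torch-optimi` for better memory efficiency."
    )
    is_optimi_available = False

The bitsandbytes and lycoris guards above follow the same shape; only the flag name and the log severity differ.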
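
A second note, on the stochastic-rounding helpers whose signatures and docstrings appear above (swap_first_and_last_dims / swap_back_first_and_last_dims): their bodies are only partially visible in these hunks, so the following is an assumed reconstruction of what such a dimension swap can look like, not the file's verbatim implementation. Swapping dim 0 with dim -1 is an involution, which is why a matching "swap back" helper can apply the same permutation.

import torch
from torch import Tensor


def swap_first_and_last_dims(tensor: Tensor) -> Tensor:
    """Exchange dim 0 and dim -1, keeping the middle dimensions in place (sketch)."""
    num_dims = len(tensor.shape)
    if num_dims < 2:
        # Nothing to swap for scalars and 1-D tensors.
        return tensor
    new_order = (num_dims - 1,) + tuple(range(1, num_dims - 1)) + (0,)
    return tensor.permute(*new_order)


x = torch.randn(2, 3, 4, 5)
y = swap_first_and_last_dims(x)
assert y.shape == (5, 3, 4, 2)
# Applying the swap twice restores the original layout.
assert swap_first_and_last_dims(y).shape == x.shape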