7  | 7  | """
8  | 8  |
9  | 9  | import logging
10 |    | -import os
11 | 10 | import random
12 | 11 | from contextlib import contextmanager
13 | 12 | from functools import partial
14 | 13 |
15 | 14 | import numpy as np
16 |    | -from einops import rearrange, repeat
17 |    | -from tqdm import tqdm
18 |    | -
19 |    | -mainlogger = logging.getLogger("mainlogger")
20 |    | -
21 | 15 | import peft
22 | 16 | import pytorch_lightning as pl
23 | 17 | import torch
24 |    | -import torch.nn as nn
   | 18 | +from einops import rearrange, repeat
25 | 19 | from pytorch_lightning.utilities import rank_zero_only
26 | 20 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
27 | 21 | from torchvision.utils import make_grid
   | 22 | +from tqdm import tqdm
28 | 23 |
29 | 24 | from videotuna.base.ddim import DDIMSampler
30 |    | -from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl
   | 25 | +from videotuna.base.distributions import DiagonalGaussianDistribution
31 | 26 | from videotuna.base.ema import LitEma
32 |    | -from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
33 | 27 |
34 | 28 | # import rlhf utils
35 | 29 | from videotuna.lvdm.models.rlhf_utils.batch_ddim import batch_ddim_sampling
36 | 30 | from videotuna.lvdm.models.rlhf_utils.reward_fn import aesthetic_loss_fn
37 | 31 | from videotuna.lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler
38 |    | -from videotuna.lvdm.modules.utils import (
39 |    | -    default,
40 |    | -    disabled_train,
41 |    | -    exists,
42 |    | -    extract_into_tensor,
43 |    | -    noise_like,
44 |    | -)
   | 32 | +from videotuna.lvdm.modules.utils import default, disabled_train, extract_into_tensor
45 | 33 | from videotuna.utils.common_utils import instantiate_from_config
46 | 34 |
47 | 35 | __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
48 | 36 |
49 | 37 |
   | 38 | +mainlogger = logging.getLogger("mainlogger")
   | 39 | +
   | 40 | +
50 | 41 | class DDPMFlow(pl.LightningModule):
51 | 42 |     # classic DDPM with Gaussian diffusion, in image space
52 | 43 |     def __init__(

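This hunk drops a group of imports (os, torch.nn, normal_kl, make_beta_schedule, rescale_zero_terminal_snr, exists, noise_like), regroups einops and tqdm with the other third-party imports, and moves the module-level mainlogger definition below the import block so no executable statement sits between imports. A minimal sketch of the layout the hunk converges on, trimmed to a few representative names; the final debug call is purely illustrative:

    import logging

    # third-party imports stay in one block
    import torch
    from einops import rearrange, repeat
    from tqdm import tqdm

    # project imports follow
    from videotuna.base.ema import LitEma

    # the shared module logger is created only after all imports
    mainlogger = logging.getLogger("mainlogger")
    mainlogger.debug("imports resolved")  # illustrative use of the shared logger
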
@@ -430,7 +421,7 @@ def load_lora_from_ckpt(self, model, path):
430 | 421 |                     f"Parameter {key} from lora_state_dict was not copied to the model."
431 | 422 |                 )
432 | 423 |             # print(f"Parameter {key} from lora_state_dict was not copied to the model.")
433 |     | -        print(f"All Parameters was copied successfully.")
    | 424 | +        print("All Parameters was copied successfully.")
434 | 425 |
435 | 426 |     def inject_lora(self):
436 | 427 |         """inject lora into the denoising module.

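The only change in this hunk drops the f-prefix from a string literal that has no placeholders, the pattern linters such as Ruff report as F541. A small illustrative sketch of the distinction, using a hypothetical parameter name:

    key = "unet.lora_A.weight"  # hypothetical parameter name

    # the f-prefix is needed only when the string interpolates a value
    print(f"Parameter {key} from lora_state_dict was not copied to the model.")

    # nothing to interpolate here, so a plain string literal is the cleaner form
    print("All Parameters was copied successfully.")
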
@@ -519,7 +510,7 @@ def __init__(
519 | 510 |
520 | 511 |         try:
521 | 512 |             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
522 |     | -        except:
    | 513 | +        except Exception:
523 | 514 |             self.num_downs = 0
524 | 515 |         if not scale_by_std:
525 | 516 |             self.scale_factor = scale_factor

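Replacing the bare except with except Exception keeps the fallback to num_downs = 0 but no longer swallows KeyboardInterrupt and SystemExit, which derive from BaseException rather than Exception. A minimal, runnable sketch of the pattern; the config object is a stand-in, not the real first_stage_config:

    class _DDConfig:
        ch_mult = (1, 2, 4, 4)  # stand-in for first_stage_config.params.ddconfig

    try:
        num_downs = len(_DDConfig.ch_mult) - 1
    except Exception:  # catches AttributeError/TypeError, but not Ctrl-C
        num_downs = 0

    print(num_downs)  # -> 3
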
@@ -1586,7 +1577,7 @@ def configure_optimizers(self):
1586 | 1577 |
1587 | 1578 |         if self.cond_stage_trainable:
1588 | 1579 |             params_cond_stage = [
1589 |      | -                p for p in self.cond_stage_model.parameters() if p.requires_grad == True
     | 1580 | +                p for p in self.cond_stage_model.parameters() if p.requires_grad is True
1590 | 1581 |             ]
1591 | 1582 |             mainlogger.info(
1592 | 1583 |                 f"@Training [{len(params_cond_stage)}] Paramters for Cond_stage_model."
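
Since Parameter.requires_grad is a plain bool, "is True" and "== True" select exactly the same parameters here; the rewrite mainly silences the E712 comparison warning. Relying on truthiness directly is the most idiomatic form, as in this sketch (the small Linear module is an illustrative stand-in for cond_stage_model):

    import torch.nn as nn

    model = nn.Linear(4, 2)           # stand-in for cond_stage_model
    model.bias.requires_grad = False  # freeze one parameter

    # requires_grad is a bool, so a plain truthiness test filters the same set
    trainable = [p for p in model.parameters() if p.requires_grad]
    print(f"@Training [{len(trainable)}] parameters")  # prints 1 (only the weight)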