6 changes: 3 additions & 3 deletions cache_latents.py
@@ -1,6 +1,6 @@
import argparse
import os
from typing import Optional, Union
from typing import Optional, Union, List

import numpy as np
import torch
@@ -86,7 +86,7 @@ def show_console(


def show_datasets(
datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
datasets: List[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
):
print(f"d: next dataset, q: quit")

@@ -119,7 +119,7 @@ def show_datasets(
batch_index += 1


def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: List[ItemInfo]):
contents = torch.stack([torch.from_numpy(item.content) for item in batch])
if len(contents.shape) == 4:
contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
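The `list[...]` → `List[...]` (and `tuple`/`dict` → `Tuple`/`Dict`) changes in this file and the files below appear to target Python versions older than 3.9, where subscripting the built-in container types in annotations raises a `TypeError` at definition time. A minimal sketch of the difference, assuming Python 3.8 is the target (the helper names below are illustrative, not part of the repository):

```python
# Rough sketch of the motivation, assuming the target includes Python 3.8:
# there, subscripting built-in containers in annotations fails when the
# function is defined, e.g.
#   def f(items: list[int]) -> None: ...
#   TypeError: 'type' object is not subscriptable
# whereas the typing aliases work on both old and new interpreters.
from typing import Dict, List, Tuple


def bucket_key(width: int, height: int) -> Tuple[int, int]:
    # illustrative helper, not part of the repository
    return (width, height)


def count_by_bucket(sizes: List[Tuple[int, int]]) -> Dict[Tuple[int, int], int]:
    counts: Dict[Tuple[int, int], int] = {}
    for size in sizes:
        counts[size] = counts.get(size, 0) + 1
    return counts


print(count_by_bucket([(512, 512), (512, 512), (640, 384)]))
```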
6 changes: 3 additions & 3 deletions cache_text_encoder_outputs.py
@@ -1,6 +1,6 @@
import argparse
import os
from typing import Optional, Union
from typing import Optional, Union, List

import numpy as np
import torch
@@ -22,7 +22,7 @@
logging.basicConfig(level=logging.INFO)


def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, List[str]]):
data_type = "video" # video only, image is not supported
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

@@ -33,7 +33,7 @@ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):


def encode_and_save_batch(
text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
text_encoder: TextEncoder, batch: List[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
):
prompts = [item.caption for item in batch]
# print(prompts)
56 changes: 28 additions & 28 deletions dataset/image_video_dataset.py
@@ -5,7 +5,7 @@
import os
import random
import time
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union, List, Dict

import numpy as np
import torch
@@ -103,7 +103,7 @@ def divisible_by(num: int, divisor: int) -> int:
return num - num % divisor


def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
"""
Resize the image to the bucket resolution.
"""
@@ -147,8 +147,8 @@ def __init__(
self,
item_key: str,
caption: str,
original_size: tuple[int, int],
bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
original_size: Tuple[int, int],
bucket_size: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None,
frame_count: Optional[int] = None,
content: Optional[np.ndarray] = None,
latent_cache_path: Optional[str] = None,
@@ -272,7 +272,7 @@ def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_u
# calculate aspect ratio to find the nearest resolution
self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
def get_bucket_resolution(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
"""
return the bucket resolution for the given image size, (width, height)
"""
@@ -294,7 +294,7 @@ def load_video(
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> list[np.ndarray]:
) -> List[np.ndarray]:
container = av.open(video_path)
video = []
bucket_reso = None
@@ -320,7 +320,7 @@

class BucketBatchManager:

def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
def __init__(self, bucketed_item_info: Dict[Tuple[int, int], List[ItemInfo]], batch_size: int):
self.batch_size = batch_size
self.buckets = bucketed_item_info
self.bucket_resos = list(self.buckets.keys())
@@ -402,7 +402,7 @@ def set_caption_only(self, caption_only: bool):
def is_indexable(self):
return False

def get_caption(self, idx: int) -> tuple[str, str]:
def get_caption(self, idx: int) -> Tuple[str, str]:
"""
Returns caption. May not be called if is_indexable() returns False.
"""
@@ -422,7 +422,7 @@ class ImageDatasource(ContentDatasource):
def __init__(self):
super().__init__()

def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
def get_image_data(self, idx: int) -> Tuple[str, Image.Image, str]:
"""
Returns image data as a tuple of image path, image, and caption for the given index.
Key must be unique and valid as a file name.
@@ -449,15 +449,15 @@ def is_indexable(self):
def __len__(self):
return len(self.image_paths)

def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
def get_image_data(self, idx: int) -> Tuple[str, Image.Image, str]:
image_path = self.image_paths[idx]
image = Image.open(image_path).convert("RGB")

_, caption = self.get_caption(idx)

return image_path, image, caption

def get_caption(self, idx: int) -> tuple[str, str]:
def get_caption(self, idx: int) -> Tuple[str, str]:
image_path = self.image_paths[idx]
caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
with open(caption_path, "r", encoding="utf-8") as f:
@@ -517,7 +517,7 @@ def is_indexable(self):
def __len__(self):
return len(self.data)

def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
def get_image_data(self, idx: int) -> Tuple[str, Image.Image, str]:
data = self.data[idx]
image_path = data["image_path"]
image = Image.open(image_path).convert("RGB")
@@ -526,7 +526,7 @@ def get_image_data(self, idx: int) -> Tuple[str, Image.Image, str]:

return image_path, image, caption

def get_caption(self, idx: int) -> tuple[str, str]:
def get_caption(self, idx: int) -> Tuple[str, str]:
data = self.data[idx]
image_path = data["image_path"]
caption = data["caption"]
@@ -577,7 +577,7 @@ def get_video_data_from_path(
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> tuple[str, list[Image.Image], str]:
) -> Tuple[str, List[Image.Image], str]:
# this method can resize the video if bucket_selector is given to reduce the memory usage

start_frame = start_frame if start_frame is not None else self.start_frame
@@ -625,15 +625,15 @@ def get_video_data(
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> tuple[str, list[Image.Image], str]:
) -> Tuple[str, List[Image.Image], str]:
video_path = self.video_paths[idx]
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

_, caption = self.get_caption(idx)

return video_path, video, caption

def get_caption(self, idx: int) -> tuple[str, str]:
def get_caption(self, idx: int) -> Tuple[str, str]:
video_path = self.video_paths[idx]
caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
with open(caption_path, "r", encoding="utf-8") as f:
@@ -693,7 +693,7 @@ def get_video_data(
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> tuple[str, list[Image.Image], str]:
) -> Tuple[str, List[Image.Image], str]:
data = self.data[idx]
video_path = data["video_path"]
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
Expand All @@ -702,7 +702,7 @@ def get_video_data(

return video_path, video, caption

def get_caption(self, idx: int) -> tuple[str, str]:
def get_caption(self, idx: int) -> Tuple[str, str]:
data = self.data[idx]
video_path = data["video_path"]
caption = data["caption"]
@@ -823,7 +823,7 @@ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: Conten
datasource.set_caption_only(True)
executor = ThreadPoolExecutor(max_workers=num_workers)

data: list[ItemInfo] = []
data: List[ItemInfo] = []
futures = []

def aggregate_future(consume_all: bool = False):
@@ -921,7 +921,7 @@ def retrieve_latent_cache_batches(self, num_workers: int):
buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
executor = ThreadPoolExecutor(max_workers=num_workers)

batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
batches: Dict[Tuple[int, int], List[ItemInfo]] = {} # (width, height) -> [ItemInfo]
futures = []

def aggregate_future(consume_all: bool = False):
@@ -961,7 +961,7 @@ def submit_batch(flush: bool = False):

for fetch_op in self.datasource:

def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
def fetch_and_resize(op: callable) -> Tuple[Tuple[int, int], str, Image.Image, str]:
image_key, image, caption = op()
image: Image.Image
image_size = image.size
@@ -998,7 +998,7 @@ def prepare_for_training(self):
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

# assign cache files to item info
bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
bucketed_item_info: Dict[tuple[int, int], List[ItemInfo]] = {} # (width, height) -> [ItemInfo]
for cache_file in latent_cache_files:
tokens = os.path.basename(cache_file).split("_")

@@ -1053,7 +1053,7 @@ def __init__(
frame_extraction: Optional[str] = "head",
frame_stride: Optional[int] = 1,
frame_sample: Optional[int] = 1,
target_frames: Optional[list[int]] = None,
target_frames: Optional[List[int]] = None,
video_directory: Optional[str] = None,
video_jsonl_file: Optional[str] = None,
cache_directory: Optional[str] = None,
@@ -1106,7 +1106,7 @@ def retrieve_latent_cache_batches(self, num_workers: int):
executor = ThreadPoolExecutor(max_workers=num_workers)

# key: (width, height, frame_count), value: [ItemInfo]
batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
batches: Dict[Tuple[int, int, int], List[ItemInfo]] = {}
futures = []

def aggregate_future(consume_all: bool = False):
@@ -1184,9 +1184,9 @@ def submit_batch(flush: bool = False):

for operator in self.datasource:

def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
def fetch_and_resize(op: callable) -> Tuple[Tuple[int, int], str, List[np.ndarray], str]:
video_key, video, caption = op()
video: list[np.ndarray]
video: List[np.ndarray]
frame_size = (video[0].shape[1], video[0].shape[0])

# resize if necessary
@@ -1223,7 +1223,7 @@ def prepare_for_training(self):
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

# assign cache files to item info
bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
bucketed_item_info: Dict[Tuple[int, int, int], List[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
for cache_file in latent_cache_files:
tokens = os.path.basename(cache_file).split("_")

@@ -1274,7 +1274,7 @@ def __getitem__(self, idx):
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
super().__init__(datasets)
self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
self.datasets: List[Union[ImageDataset, VideoDataset]] = datasets
self.num_train_items = 0
for dataset in self.datasets:
self.num_train_items += dataset.num_train_items
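The `Dict[Tuple[int, int], List[ItemInfo]]` annotations above describe buckets keyed by resolution, each holding the items that will later be cut into fixed-size batches. A rough sketch of that grouping pattern, using placeholder item keys instead of `ItemInfo` (names are illustrative, not the repository's API):

```python
# Sketch of bucket-keyed batching: group items by (width, height), then slice
# each bucket into batches of at most batch_size items.
from typing import Dict, List, Tuple


def make_batches(
    buckets: Dict[Tuple[int, int], List[str]], batch_size: int
) -> List[List[str]]:
    batches: List[List[str]] = []
    for reso in sorted(buckets.keys()):
        items = buckets[reso]
        for i in range(0, len(items), batch_size):
            batches.append(items[i : i + batch_size])
    return batches


print(make_batches({(512, 512): ["a", "b", "c"], (640, 384): ["d"]}, batch_size=2))
# [['a', 'b'], ['c'], ['d']]
```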
4 changes: 2 additions & 2 deletions hunyuan_model/models.py
@@ -756,11 +756,11 @@ def forward(
if self.attn_mode == "torch" and not self.split_attn:
# initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
bs = img.shape[0]
attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
attn_mask = torch.zeros((bs, 1, max_seqlen_q), dtype=torch.bool, device=text_mask.device)

# set attention mask with total_len
for i in range(bs):
attn_mask[i, :, : total_len[i], : total_len[i]] = True
attn_mask[i, :, : total_len[i]] = True
Comment on lines -759 to +763
Owner

This fix seems to result in the following error in PyTorch 2.5.1:

x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
RuntimeError: The expanded size of the tensor (24) must match the existing size (3) at non-singleton dimension 1.  Target sizes: [3, 24, 2296, 2296].  Tensor sizes: [3, 1, 2296]

Could you please revert this fix? Without this fix, it would work with --split_attn on versions prior to PyTorch 2.5.1.
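For context, the error above is consistent with how `F.scaled_dot_product_attention` broadcasts its mask: `attn_mask` must be broadcastable against the `(batch, heads, q_len, kv_len)` attention scores, so a `(b, 1, n, n)` mask broadcasts cleanly, while a 3-D `(b, 1, n)` mask right-aligns its batch dimension against the heads dimension and fails. A small reproduction sketch with made-up shapes (not the model's real sizes):

```python
# Hypothetical shapes, only to illustrate the broadcasting rule behind the error.
import torch
import torch.nn.functional as F

b, heads, n, head_dim = 2, 4, 16, 8
q = k = v = torch.randn(b, heads, n, head_dim)

# (b, 1, n, n): broadcastable to the (b, heads, n, n) attention scores -> works.
mask_4d = torch.zeros(b, 1, n, n, dtype=torch.bool)
mask_4d[:, :, :, : n // 2] = True
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_4d)

# (b, 1, n): right-aligned broadcasting pits b against heads -> RuntimeError,
# matching the failure reported in the comment above.
mask_3d = torch.zeros(b, 1, n, dtype=torch.bool)
mask_3d[:, :, : n // 2] = True
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask_3d)
except RuntimeError as e:
    print(e)
```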

total_len = None # means we don't use split_attn

freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
2 changes: 1 addition & 1 deletion hunyuan_model/token_refiner.py
@@ -226,7 +226,7 @@ def forward(
else:
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
context_aware_representations = self.c_embedder(context_aware_representations.to(dtype=x.dtype))
c = timestep_aware_representations + context_aware_representations

x = self.input_embedder(x)
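The added `.to(dtype=x.dtype)` presumably addresses a dtype mismatch: `mask.float()` promotes the pooled representation to float32, so under bf16/fp16 execution it would otherwise reach `c_embedder` in a different dtype than the layer's weights. A standalone sketch of the pattern, with illustrative sizes rather than the model's real ones:

```python
# Illustrative only: masked mean-pooling promotes to float32, so the result is
# cast back to the working dtype before a Linear layer whose weights are bf16.
import torch
import torch.nn as nn

hidden = 32
c_embedder = nn.Linear(hidden, hidden, dtype=torch.bfloat16)

x = torch.randn(2, 10, hidden, dtype=torch.bfloat16)         # [b, s1, d]
mask = torch.ones(2, 10, dtype=torch.bool)                    # [b, s1]

mask_float = mask.float().unsqueeze(-1)                       # [b, s1, 1], float32
pooled = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)  # promoted to float32

# Without the cast, float32 input vs. bf16 weights raises a dtype mismatch error.
out = c_embedder(pooled.to(dtype=x.dtype))
print(out.dtype)  # torch.bfloat16
```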
4 changes: 2 additions & 2 deletions hv_generate_video.py
@@ -5,7 +5,7 @@
import sys
import os
import time
from typing import Optional, Union
from typing import Optional, Union, List

import numpy as np
import torch
@@ -137,7 +137,7 @@ def save_images_grid(videos: torch.Tensor, parent_dir: str, image_name: str, res
# region Encoding prompt


def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
def encode_prompt(prompt: Union[str, List[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
r"""
Encodes the prompt into text encoder hidden states.

6 changes: 3 additions & 3 deletions hv_train.py
@@ -13,7 +13,7 @@
import time
import json
from multiprocessing import Value
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import accelerate
import numpy as np
from packaging.version import Version
@@ -206,7 +206,7 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict


def load_prompts(prompt_file: str) -> list[Dict]:
def load_prompts(prompt_file: str) -> List[Dict]:
# read prompts
if prompt_file.endswith(".txt"):
with open(prompt_file, "r", encoding="utf-8") as f:
@@ -359,7 +359,7 @@ def encode_for_text_encoder(text_encoder, is_llm=True):

return sample_parameters

def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
def get_optimizer(self, args, trainable_params: List[torch.nn.Parameter]) -> Tuple[str, str, torch.optim.Optimizer]:
# adamw, adamw8bit, adafactor

optimizer_type = args.optimizer_type.lower()
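The `get_optimizer` signature above returns a `(str, str, torch.optim.Optimizer)` tuple and mentions adamw, adamw8bit, and adafactor. Its body is not shown in this diff, so the following is only a guess at the general shape of such a selector (the supported types, helper name, and the meaning of the two returned strings are assumptions):

```python
# Hypothetical sketch of an optimizer selector with the annotated return type.
from typing import List, Tuple
import torch


def get_optimizer_sketch(
    optimizer_type: str, trainable_params: List[torch.nn.Parameter], lr: float
) -> Tuple[str, str, torch.optim.Optimizer]:
    optimizer_type = optimizer_type.lower()
    if optimizer_type == "adamw":
        optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    elif optimizer_type == "sgd":
        optimizer = torch.optim.SGD(trainable_params, lr=lr)
    else:
        raise ValueError(f"unsupported optimizer type: {optimizer_type}")
    # Mirror the (str, str, Optimizer) annotation: a qualified class name,
    # a readable argument string, and the optimizer instance.
    name = f"{optimizer.__class__.__module__}.{optimizer.__class__.__name__}"
    return name, f"lr={lr}", optimizer
```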
6 changes: 3 additions & 3 deletions hv_train_network.py
@@ -13,7 +13,7 @@
import time
import json
from multiprocessing import Value
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import accelerate
import numpy as np
from packaging.version import Version
@@ -218,7 +218,7 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict


def load_prompts(prompt_file: str) -> list[Dict]:
def load_prompts(prompt_file: str) -> List[Dict]:
# read prompts
if prompt_file.endswith(".txt"):
with open(prompt_file, "r", encoding="utf-8") as f:
@@ -434,7 +434,7 @@ def encode_for_text_encoder(text_encoder, is_llm=True):

return sample_parameters

def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
def get_optimizer(self, args, trainable_params: List[torch.nn.Parameter]) -> Tuple[str, str, torch.optim.Optimizer]:
# adamw, adamw8bit, adafactor

optimizer_type = args.optimizer_type.lower()