17 changes: 13 additions & 4 deletions cache_latents.py
@@ -167,10 +167,19 @@ def main(args):
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args)

train_dataset_group = config_utils.generate_dataset_group_by_blueprint(
blueprint["train_dataset_group"], training=False
)
val_dataset_group = config_utils.generate_dataset_group_by_blueprint(
blueprint["val_dataset_group"], training=False
)

all_datasets = train_dataset_group.datasets + val_dataset_group.datasets

if args.debug_mode is not None:
show_datasets(all_datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
return
@@ -195,7 +204,7 @@ def main(args):

# Encode images
num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
for i, dataset in enumerate(datasets):
for i, dataset in enumerate(all_datasets):
logger.info(f"Encoding dataset [{i}]")
all_latent_cache_paths = []
for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
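For reference, a minimal sketch of how the reworked blueprint API is consumed. This is a sketch under assumptions, not the repo's exact config: the "general"/"datasets"/"val_datasets" keys and the training=False flag come from the diff above, while the directory keys and values are made up.

import argparse
from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer

# hypothetical in-memory stand-in for a dataset config .toml
user_config = {
    "general": {"caption_extension": ".txt"},
    "datasets": [{"video_directory": "data/train"}],    # training datasets
    "val_datasets": [{"video_directory": "data/val"}],  # validation datasets
}
args = argparse.Namespace()  # the real scripts pass their parsed CLI args here

blueprint = BlueprintGenerator(ConfigSanitizer()).generate(user_config, args)
train_group = config_utils.generate_dataset_group_by_blueprint(blueprint["train_dataset_group"], training=False)
val_group = config_utils.generate_dataset_group_by_blueprint(blueprint["val_dataset_group"], training=False)
all_datasets = train_group.datasets + val_group.datasets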
137 changes: 70 additions & 67 deletions cache_text_encoder_outputs.py
@@ -5,25 +5,23 @@
import numpy as np
import torch
from tqdm import tqdm
import accelerate

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from dataset.image_video_dataset import ItemInfo, save_text_encoder_output_cache
from hunyuan_model import text_encoder as text_encoder_module
from hunyuan_model.text_encoder import TextEncoder

import logging

from utils.model_utils import str_to_dtype

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
data_type = "video" # video only, image is not supported
data_type = "video" # typically 'video'; can be changed if needed
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

with torch.no_grad():
@@ -35,105 +33,114 @@ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
def encode_and_save_batch(
text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
):
# aggregate prompts
prompts = [item.caption for item in batch]

# encode
if accelerator is not None:
with accelerator.autocast():
prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
else:
prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)

# save
for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
save_text_encoder_output_cache(item, embed, mask, is_llm)

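# Usage sketch (hypothetical; real ItemInfo batches come from a dataset's
# retrieve_text_encoder_output_cache_batches generator, as in main below):
#   batch = [item_a, item_b]  # each ItemInfo carries .caption and a cache path
#   encode_and_save_batch(text_encoder, batch, is_llm=True, accelerator=None)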

def main(args):
# pick device
device = args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device(device)

# load config
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args)

# parse train + val dataset groups
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(
blueprint["train_dataset_group"], training=False
)
val_dataset_group = config_utils.generate_dataset_group_by_blueprint(
blueprint["val_dataset_group"], training=False
)

# combine
all_datasets = train_dataset_group.datasets + val_dataset_group.datasets

# optional: a debug preview of the datasets could go here, e.g.:
# if args.debug_mode:
# debug_display_function(all_datasets, ...)
# return

# optional: set up accelerator for fp8, if desired
accelerator = None
if args.fp8_llm:
accelerator = accelerate.Accelerator(mixed_precision="fp16")

# prepare batch encoding function
num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)

# each dataset can produce text-encoder batches
all_cache_files_for_dataset = []
all_cache_paths_for_dataset = []
for dataset in all_datasets:
cache_files = dataset.get_all_text_encoder_output_cache_files()
cache_files = [os.path.normpath(f) for f in cache_files]
all_cache_files_for_dataset.append(set(cache_files))
all_cache_paths_for_dataset.append(set())

# helper to encode with a given text encoder
def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool):
for i, dataset in enumerate(all_datasets):
logger.info(f"Encoding dataset [{i}] with text encoder {('LLM' if is_llm else 'CLIPL')}")
existing_cache_files = all_cache_files_for_dataset[i]
new_seen_cache = all_cache_paths_for_dataset[i]

# stream batches
for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
# record these paths
new_seen_cache.update(os.path.normpath(item.text_encoder_output_cache_path) for item in batch)

# skip existing
if args.skip_existing:
filtered_batch = [
item for item in batch
if os.path.normpath(item.text_encoder_output_cache_path) not in existing_cache_files
]
# print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files")
if len(filtered_batch) == 0:
if not filtered_batch:
continue
batch = filtered_batch

# chunk into mini-batches for memory
bs = args.batch_size or len(batch)
for start_idx in range(0, len(batch), bs):
sub_batch = batch[start_idx : start_idx + bs]
encode_and_save_batch(text_encoder, sub_batch, is_llm, accelerator)

# load + encode with text encoder 1
text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
logger.info(f"loading text encoder 1: {args.text_encoder1}")
logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
text_encoder_1.to(device=device)

logger.info("Encoding with Text Encoder 1")
encode_for_text_encoder(text_encoder_1, is_llm=True)
del text_encoder_1

# load + encode with text encoder 2
logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
text_encoder_2.to(device=device)

logger.info("Encoding with Text Encoder 2")
encode_for_text_encoder(text_encoder_2, is_llm=False)
del text_encoder_2

# remove old cache files not in dataset
for i, dataset in enumerate(all_datasets):
existing_cache_files = all_cache_files_for_dataset[i]
new_seen_cache = all_cache_paths_for_dataset[i]
for cache_file in existing_cache_files:
if cache_file not in new_seen_cache:
if args.keep_cache:
logger.info(f"Keep cache file not in the dataset: {cache_file}")
else:
@@ -143,24 +150,20 @@ def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool):

def setup_parser():
parser = argparse.ArgumentParser()

parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
parser.add_argument(
"--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
)
parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 path or dir")
parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 path or dir")
parser.add_argument("--device", type=str, default=None, help="Device to use (e.g. cuda, cpu). Default auto.")
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="FP precision, default float16")
parser.add_argument("--fp8_llm", action="store_true", help="Use fp8 for Text Encoder 1 (LLM)")
parser.add_argument("--batch_size", type=int, default=None, help="Optional batch size override")
parser.add_argument("--num_workers", type=int, default=None, help="Number of DataLoader workers. Default = CPU-1")
parser.add_argument("--skip_existing", action="store_true", help="Skip items with existing cache files")
parser.add_argument("--keep_cache", action="store_true", help="Keep old cache files not in the dataset")
return parser


if __name__ == "__main__":
parser = setup_parser()

args = parser.parse_args()
main(args)
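The cache bookkeeping above reduces to two path sets per dataset; a standalone sketch of the skip/cleanup rule with made-up paths, using the same set operations as the loops above:

import os

existing = {os.path.normpath(p) for p in ("cache/a_te.safetensors", "cache/old_te.safetensors")}
seen = set()  # filled while streaming batches
seen.update(os.path.normpath(p) for p in ("cache/a_te.safetensors", "cache/b_te.safetensors"))

needs_encoding = seen - existing  # what --skip_existing leaves to encode
stale = existing - seen           # removed at the end unless --keep_cache is set
print(sorted(needs_encoding), sorted(stale))  # b_te still needs encoding; old_te is stale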
62 changes: 38 additions & 24 deletions dataset/config_utils.py
@@ -139,6 +139,7 @@ def validate_flex_dataset(dataset_config: dict):
{
"general": self.general_schema,
"datasets": [self.dataset_schema],
"val_datasets": [self.dataset_schema],
}
)
self.argparse_schema = self.__merge_dict(
@@ -182,31 +183,44 @@ class BlueprintGenerator:

def __init__(self, sanitizer: ConfigSanitizer):
self.sanitizer = sanitizer

def generate(self, user_config: dict, argparse_namespace: argparse.Namespace) -> dict:
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

# keep argparse overrides for the fallback lookups in make_dataset_blueprint
self.argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}

# store the top-level "general" section
self.general_config = sanitized_user_config.get("general", {})

# parse training datasets
train_dataset_configs = sanitized_user_config.get("datasets", [])
train_blueprints = [
self.make_dataset_blueprint(cfg) for cfg in train_dataset_configs
]
train_dataset_group_blueprint = DatasetGroupBlueprint(train_blueprints)

# parse validation datasets
val_dataset_configs = sanitized_user_config.get("val_datasets", [])
val_blueprints = [
self.make_dataset_blueprint(cfg) for cfg in val_dataset_configs
]
val_dataset_group_blueprint = DatasetGroupBlueprint(val_blueprints)

return {
"train_dataset_group": train_dataset_group_blueprint,
"val_dataset_group": val_dataset_group_blueprint,
}

def make_dataset_blueprint(self, dataset_config):
# Decide whether it's an image dataset or video dataset
is_image_dataset = "target_frames" not in dataset_config
dataset_params_klass = ImageDatasetParams if is_image_dataset else VideoDatasetParams

# Merge dataset_config with the general section (e.g. caption_extension) and argparse overrides
fallback_list = [
dataset_config,
self.general_config,
self.argparse_config,
]
params = self.generate_params_by_fallbacks(dataset_params_klass, fallback_list)

return DatasetBlueprint(is_image_dataset, params)

@staticmethod
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
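The body of generate_params_by_fallbacks is folded out of this diff; assuming it fills each parameter from the first fallback dict that provides it (which is why general_config can supply defaults such as caption_extension), a toy sketch of that precedence:

from dataclasses import dataclass, fields

@dataclass
class ToyParams:
    caption_extension: str = ".caption"
    resolution: tuple = (512, 512)

def params_by_fallbacks(klass, fallbacks):
    # take each field's value from the first dict that defines it
    kwargs = {}
    for f in fields(klass):
        for src in fallbacks:
            if f.name in src:
                kwargs[f.name] = src[f.name]
                break
    return klass(**kwargs)

dataset_cfg = {"resolution": (960, 544)}
general_cfg = {"caption_extension": ".txt"}
print(params_by_fallbacks(ToyParams, [dataset_cfg, general_cfg]))
# ToyParams(caption_extension='.txt', resolution=(960, 544))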
14 changes: 12 additions & 2 deletions hunyuan_model/token_refiner.py
@@ -233,8 +233,18 @@ def forward(
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
Contributor Author comment: I was getting a crash on my 4090 about type casting I think, this resolved it.
if x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# do the reduction in float32, a safe fallback type for float8 inputs
safe_x = x.float()  # upcast from float8
safe_mask = mask.float().unsqueeze(-1)
numerator = (safe_x * safe_mask).sum(dim=1)
denominator = safe_mask.sum(dim=1).clamp_min(1e-8) # avoid div-by-zero
out = numerator / denominator
context_aware_representations = out.to(x.dtype) # cast back to float8
else:
# the old logic for other dtypes
mask_float = mask.float().unsqueeze(-1)
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations

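The float8 branch exists because PyTorch's float8 dtypes do not support most arithmetic ops, so the masked mean must run in a wider type. A self-contained sketch of the same pattern (assumes PyTorch 2.1+ for float8_e4m3fn):

import torch

x = torch.randn(2, 5, 8, dtype=torch.float16).to(torch.float8_e4m3fn)  # [b, s, d] tokens
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])                # [b, s] valid-token mask

safe_x = x.float()                      # upcast: float8 tensors cannot be multiplied/summed directly
safe_mask = mask.float().unsqueeze(-1)  # [b, s, 1]
mean = (safe_x * safe_mask).sum(dim=1) / safe_mask.sum(dim=1).clamp_min(1e-8)
pooled = mean.to(x.dtype)               # cast back to float8
print(pooled.shape)                     # torch.Size([2, 8])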