17 changes: 13 additions & 4 deletions cache_latents.py
@@ -167,10 +167,19 @@ def main(args):
     logger.info(f"Load dataset config from {args.dataset_config}")
     user_config = config_utils.load_user_config(args.dataset_config)
     blueprint = blueprint_generator.generate(user_config, args)
-    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
-
-    datasets = train_dataset_group.datasets
+    # build both the training and validation dataset groups from the blueprint dict
+    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(
+        blueprint["train_dataset_group"], training=False
+    )
+    val_dataset_group = config_utils.generate_dataset_group_by_blueprint(
+        blueprint["val_dataset_group"], training=False
+    )
+
+    all_datasets = train_dataset_group.datasets + val_dataset_group.datasets
 
     if args.debug_mode is not None:
-        show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
+        show_datasets(all_datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
         return
@@ -195,7 +204,7 @@ def main(args):
 
     # Encode images
     num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
-    for i, dataset in enumerate(datasets):
+    for i, dataset in enumerate(all_datasets):
         logger.info(f"Encoding dataset [{i}]")
         all_latent_cache_paths = []
         for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
14 changes: 10 additions & 4 deletions cache_text_encoder_outputs.py
@@ -63,9 +63,15 @@ def main(args):
     logger.info(f"Load dataset config from {args.dataset_config}")
     user_config = config_utils.load_user_config(args.dataset_config)
     blueprint = blueprint_generator.generate(user_config, args)
-    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
-
-    datasets = train_dataset_group.datasets
+    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(
+        blueprint["train_dataset_group"],
+        training=False
+    )
+    val_dataset_group = config_utils.generate_dataset_group_by_blueprint(
+        blueprint["val_dataset_group"],
+        training=False
+    )
+    datasets = train_dataset_group.datasets + val_dataset_group.datasets
 
     # define accelerator for fp8 inference
     accelerator = None
@@ -163,4 +169,4 @@ def setup_parser():
     parser = setup_parser()
 
     args = parser.parse_args()
-    main(args)
\ No newline at end of file
+    main(args)
62 changes: 38 additions & 24 deletions dataset/config_utils.py
@@ -139,6 +139,7 @@ def validate_flex_dataset(dataset_config: dict):
             {
                 "general": self.general_schema,
                 "datasets": [self.dataset_schema],
+                "val_datasets": [self.dataset_schema],
             }
         )
         self.argparse_schema = self.__merge_dict(
@@ -182,31 +183,44 @@ class BlueprintGenerator:
 
     def __init__(self, sanitizer: ConfigSanitizer):
         self.sanitizer = sanitizer
 
-    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
-    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
+    def generate(self, user_config, argparse_namespace):
         sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
         sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
 
-        argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
-        general_config = sanitized_user_config.get("general", {})
-
-        dataset_blueprints = []
-        for dataset_config in sanitized_user_config.get("datasets", []):
-            is_image_dataset = "target_frames" not in dataset_config
-            if is_image_dataset:
-                dataset_params_klass = ImageDatasetParams
-            else:
-                dataset_params_klass = VideoDatasetParams
-
-            params = self.generate_params_by_fallbacks(
-                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
-            )
-            dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
-
-        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
-
-        return Blueprint(dataset_group_blueprint)
+        # store the top-level "general" section
+        self.general_config = sanitized_user_config.get("general", {})
+
+        # parse training datasets
+        train_dataset_configs = sanitized_user_config.get("datasets", [])
+        train_blueprints = [self.make_dataset_blueprint(cfg) for cfg in train_dataset_configs]
+        train_dataset_group_blueprint = DatasetGroupBlueprint(train_blueprints)
+
+        # parse validation datasets
+        val_dataset_configs = sanitized_user_config.get("val_datasets", [])
+        val_blueprints = [self.make_dataset_blueprint(cfg) for cfg in val_dataset_configs]
+        val_dataset_group_blueprint = DatasetGroupBlueprint(val_blueprints)
+
+        return {
+            "train_dataset_group": train_dataset_group_blueprint,
+            "val_dataset_group": val_dataset_group_blueprint,
+        }
+
+    def make_dataset_blueprint(self, dataset_config):
+        # decide whether this is an image dataset or a video dataset
+        is_image_dataset = "target_frames" not in dataset_config
+        dataset_params_klass = ImageDatasetParams if is_image_dataset else VideoDatasetParams
+
+        # merge params from the dataset config first, then the general config
+        # (so options like caption_extension fall back to [general])
+        fallback_list = [dataset_config, self.general_config]
+        params = self.generate_params_by_fallbacks(dataset_params_klass, fallback_list)
+
+        return DatasetBlueprint(is_image_dataset, params)
 
     @staticmethod
     def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
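For reference, a minimal end-to-end sketch of how the new dict-returning `generate` API fits together. This is illustrative, not an exact call site: the `ConfigSanitizer()` construction, the `Namespace` contents, and the `training=` flags are shown schematically, mirroring the scripts above.

```python
import argparse

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer

# stand-in for the parsed CLI args that the real scripts pass in
args = argparse.Namespace(dataset_config="path/to/config_with_val.toml")

user_config = config_utils.load_user_config(args.dataset_config)
blueprint = BlueprintGenerator(ConfigSanitizer()).generate(user_config, args)

# generate() now returns a dict with one DatasetGroupBlueprint per split
train_group = config_utils.generate_dataset_group_by_blueprint(
    blueprint["train_dataset_group"], training=True
)
val_group = config_utils.generate_dataset_group_by_blueprint(
    blueprint["val_dataset_group"], training=False
)
print(len(train_group.datasets), len(val_group.datasets))
```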
85 changes: 84 additions & 1 deletion dataset/dataset_config.md
@@ -384,4 +384,87 @@ The metadata with .json file will be supported in the near future.



-->

## Updated Documentation: Including a Validation Dataset

### 1. Overview of `val_datasets`

Just as you can specify multiple training datasets under the `[[datasets]]` key, you can also specify validation datasets under `[[val_datasets]]`. The syntax and options for each validation dataset are exactly the same as for training. The script will:

- Load and cache your validation datasets the same way it does for training.
- Periodically compute a validation loss across these datasets (for example, once per epoch).
- Log the validation loss so you can monitor for over- or under-fitting.

### 2. Example TOML Configuration with Validation Datasets

Below is a minimal example that shows both training and validation datasets. Note that `[[datasets]]` is used for training datasets, while `[[val_datasets]]` is reserved for any validation sets you want to include.

```toml
[general]
caption_extension = ".txt"
batch_size = 1
enable_bucket = true
bucket_no_upscale = false

# primary training datasets
[[datasets]]
resolution = [640, 480]
video_directory = "path/to/training/video"
cache_directory = "path/to/cache/training/video"
frame_extraction = "head"
target_frames = [48]

[[datasets]]
resolution = [640, 480]
image_directory = "path/to/training/image"
cache_directory = "path/to/cache/training/image"

# ... you can add more [[datasets]] blocks if you have more training subsets ...

# validation datasets
[[val_datasets]]
resolution = [640, 480]
video_directory = "path/to/validation/video"
cache_directory = "path/to/cache/validation/video"
frame_extraction = "head"
target_frames = [48]

[[val_datasets]]
resolution = [640, 480]
image_directory = "path/to/validation/image"
cache_directory = "path/to/cache/validation/image"

# ... you can add more [[val_datasets]] blocks if you have more validation subsets ...
```

Notes on usage:

- The script treats the `[[datasets]]` entries as training data and the `[[val_datasets]]` entries as validation data. Both can be a mix of images and videos.
- Each training or validation dataset must have a unique `cache_directory` to avoid overwriting latent or text-encoder caches (see the sketch after this list for a quick way to check this).
- All of the same parameters (`resolution`, `caption_extension`, `num_repeats`, `batch_size`, etc.) work in exactly the same way under `[[val_datasets]]` as they do under `[[datasets]]`.
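Since duplicate cache directories silently overwrite each other's files, it can be worth sanity-checking a config before a long run. A minimal sketch (assumes Python 3.11+ for `tomllib`; the config path and key names follow the example above):

```python
import tomllib  # Python 3.11+; use the third-party `toml` package on older versions

with open("path/to/config_with_val.toml", "rb") as f:
    cfg = tomllib.load(f)

seen = {}
for section in ("datasets", "val_datasets"):
    for i, ds in enumerate(cfg.get(section, [])):
        cache_dir = ds.get("cache_directory")
        if cache_dir is None:
            continue  # no explicit cache directory configured for this entry
        if cache_dir in seen:
            raise ValueError(
                f"{section}[{i}] reuses cache_directory {cache_dir!r}, "
                f"already used by {seen[cache_dir]}"
            )
        seen[cache_dir] = f"{section}[{i}]"

print("all cache_directory values are unique")
```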

### 3. Running the Script with Validation

Once you have listed your training and validation datasets, run the training script as usual. For example:

```bash
accelerate launch hv_train_network.py \
    --dit path/to/DiT/model \
    --dataset_config path/to/config_with_val.toml \
    --max_train_epochs 20 \
    ... other arguments ...
```

During training, the script will:

- Load and batch the training datasets.
- Perform training epochs, computing the training loss.
- After each epoch (or at a specified interval), compute the validation loss using all `[[val_datasets]]`.
- Log the validation performance (for example, `val_loss=...`) to TensorBoard and/or WandB if logging is enabled.

### Additional Tips

- If you do not wish to do any validation, simply omit the `[[val_datasets]]` sections.
- You can add multiple validation blocks, each pointing to different image or video folders or JSONL metadata.
- The script merges all validation datasets into a single dataloader when it computes the validation loss.

With these changes to your config file, you can systematically evaluate your model's performance after each epoch (or on another schedule) by leveraging the `[[val_datasets]]` blocks.
14 changes: 12 additions & 2 deletions hunyuan_model/token_refiner.py
@@ -233,8 +233,18 @@ def forward(
         if mask is None:
             context_aware_representations = x.mean(dim=1)
         else:
-            mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
-            context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+            if x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                # do the reduction in a safer fallback type (float32), since
+                # elementwise math is not implemented for float8 tensors
+                safe_x = x.float()  # float8 -> float32
+                safe_mask = mask.float().unsqueeze(-1)  # [b, s1, 1]
+                numerator = (safe_x * safe_mask).sum(dim=1)
+                denominator = safe_mask.sum(dim=1).clamp_min(1e-8)  # avoid div-by-zero
+                out = numerator / denominator
+                context_aware_representations = out.to(x.dtype)  # cast back to float8
+            else:
+                # the old logic for other dtypes
+                mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
+                context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
         context_aware_representations = self.c_embedder(context_aware_representations)
         c = timestep_aware_representations + context_aware_representations

Author's review comment: "I was getting a crash on my 4090 about type casting, I think; this resolved it."

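A minimal standalone sketch of the failure mode the branch above guards against. This assumes a PyTorch build that ships the float8 tensor types; the shapes are made up for illustration, and on current builds the old code path typically raises a RuntimeError for float8 inputs.

```python
import torch

# simulated fp8 activations and a [b, s1] attention mask
x = torch.randn(2, 4, 8).to(torch.float8_e4m3fn)
mask = torch.ones(2, 4)

try:
    # elementwise multiply/sum are not generally implemented for float8
    # tensors in eager mode, so the old code path errors out here
    _ = (x * mask.float().unsqueeze(-1)).sum(dim=1)
except RuntimeError as err:
    print(f"float8 path failed: {err}")

# the workaround: upcast, reduce, then cast back to the original dtype
mask_f = mask.float().unsqueeze(-1)
pooled = (x.float() * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp_min(1e-8)
print(pooled.to(x.dtype).dtype)  # torch.float8_e4m3fn
```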
120 changes: 115 additions & 5 deletions hv_train_network.py
@@ -1112,7 +1112,8 @@ def train(self, args):
         logger.info(f"Load dataset config from {args.dataset_config}")
         user_config = config_utils.load_user_config(args.dataset_config)
         blueprint = blueprint_generator.generate(user_config, args)
-        train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
+        train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint["train_dataset_group"], training=True)
+        val_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint["val_dataset_group"], training=True)
 
         current_epoch = Value("i", 0)
         current_step = Value("i", 0)
@@ -1282,6 +1283,111 @@ def train(self, args):
            num_workers=n_workers,
            persistent_workers=args.persistent_data_loader_workers,
        )

        val_collator = collator_class(current_epoch, current_step, ds_for_collator)  # same as train
        val_dataloader = torch.utils.data.DataLoader(
            val_dataset_group,
            batch_size=1,
            shuffle=False,
            collate_fn=val_collator,
            num_workers=args.max_data_loader_n_workers,
        )

        def validate(accelerator, transformer, val_dataloader, noise_scheduler, args):
            # unwrap and switch to inference mode
            unwrapped_model = accelerator.unwrap_model(transformer)
            unwrapped_model.switch_block_swap_for_inference()
            unwrapped_model.eval()

            # fixed timestep and fixed seed so the metric is comparable across epochs
            fixed_timesteps = [500]
            fixed_seed = 42
            losses = []
            pos_embed_cache = {}

            with torch.no_grad():
                for step, batch in enumerate(val_dataloader):
                    latents, llm_embeds, llm_mask, clip_embeds = batch

                    # put everything on the correct device/dtype
                    latents = latents.to(accelerator.device, dtype=torch.float32)
                    latents = latents * vae_module.SCALING_FACTOR
                    llm_embeds = llm_embeds.to(accelerator.device, dtype=unwrapped_model.dtype)
                    llm_mask = llm_mask.to(accelerator.device)
                    clip_embeds = clip_embeds.to(accelerator.device, dtype=unwrapped_model.dtype)

                    # override the RNG inside a local fork_rng context
                    with torch.random.fork_rng(devices=[accelerator.device]):
                        torch.manual_seed(fixed_seed)
                        if accelerator.device.type == "cuda":
                            torch.cuda.manual_seed(fixed_seed)

                        noise = torch.randn_like(latents)  # stable noise on every call

                    # pick a fixed timestep; to cycle through several per batch,
                    # index into fixed_timesteps or average the losses across them
                    t = fixed_timesteps[0]
                    timesteps = torch.full((latents.size(0),), t, device=accelerator.device, dtype=torch.float32)

                    # build "noisy_model_input" deterministically
                    sigma = get_sigmas(noise_scheduler, timesteps, accelerator.device, latents.dim(), latents.dtype)
                    noisy_model_input = sigma * noise + (1.0 - sigma) * latents

                    # rotary position embeddings, cached per latent shape
                    pos_emb_shape = latents.shape[-3:]
                    if pos_emb_shape not in pos_embed_cache:
                        freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(unwrapped_model, pos_emb_shape)
                        pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
                    else:
                        freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]

                    # forward pass
                    guidance_vec = torch.full((latents.size(0),), float(args.guidance_scale), device=accelerator.device)
                    pred = unwrapped_model(
                        noisy_model_input,
                        timesteps,
                        text_states=llm_embeds,
                        text_mask=llm_mask,
                        text_states_2=clip_embeds,
                        freqs_cos=freqs_cos,
                        freqs_sin=freqs_sin,
                        guidance=guidance_vec,
                        return_dict=False,
                    )

                    # MSE against the flow-matching target
                    target = noise - latents
                    val_loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
                    losses.append(val_loss.item())

            # switch back to train mode
            unwrapped_model.train()
            unwrapped_model.switch_block_swap_for_training()

            # guard against an empty val_dataloader (no [[val_datasets]] configured)
            return float(np.mean(losses)) if losses else float("nan")

        # calculate max_train_steps
        if args.max_train_epochs is not None:
@@ -1603,15 +1709,14 @@ def remove_model(old_ckpt_name):
             noisy_model_input.requires_grad_(True)
             guidance_vec.requires_grad_(True)
 
-            pos_emb_shape = latents.shape[1:]
+            pos_emb_shape = latents.shape[2:]  # (frames, height, width)
             if pos_emb_shape not in pos_embed_cache:
-                freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:])
-                # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
-                # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
+                freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, pos_emb_shape)
                 pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
             else:
                 freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
-
 
             # call DiT
             latents = latents.to(device=accelerator.device, dtype=network_dtype)
             noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
@@ -1741,6 +1846,11 @@ def remove_model(old_ckpt_name):
 
             sample_images(accelerator, args, epoch + 1, global_step, vae, transformer, sample_parameters, dit_dtype)
             optimizer_train_fn()
+
+            # Do validation
+            val_loss = validate(accelerator, transformer, val_dataloader, noise_scheduler, args)
+            accelerator.print(f"[Epoch {epoch+1}] val_loss={val_loss:0.5f}")
+            accelerator.log({"val_loss": val_loss}, step=global_step)

# end of epoch

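For reference, the quantity `validate` reports is the flow-matching objective at a fixed timestep, as read from the code above: `sigma_t` is the scheduler noise level from `get_sigmas`, `f_theta` is the DiT, `c` is the text conditioning, and `t = 500` with a fixed noise seed so the metric is comparable across epochs.

```latex
x_t = \sigma_t\,\epsilon + (1-\sigma_t)\,x_0, \qquad \epsilon \sim \mathcal{N}(0, I),
\qquad
\mathcal{L}_{\mathrm{val}} = \mathbb{E}\big[\,\lVert f_\theta(x_t, t, c) - (\epsilon - x_0)\rVert_2^2\,\big]\Big|_{t=500}
```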