8 changes: 5 additions & 3 deletions cache_latents.py
@@ -167,9 +167,10 @@ def main(args):
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

datasets = train_dataset_group.datasets
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.train_dataset_group, training=True)
val_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.val_dataset_group, training=True)
# Combine all datasets so that both training and validation sets are processed:
datasets = train_dataset_group.datasets + val_dataset_group.datasets

if args.debug_mode is not None:
show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
@@ -267,3 +268,4 @@ def setup_parser():

args = parser.parse_args()
main(args)

7 changes: 4 additions & 3 deletions cache_text_encoder_outputs.py
@@ -63,9 +63,9 @@ def main(args):
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

datasets = train_dataset_group.datasets
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.train_dataset_group, training=True)
val_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.val_dataset_group, training=True)
datasets = train_dataset_group.datasets + val_dataset_group.datasets

# define accelerator for fp8 inference
accelerator = None
@@ -164,3 +164,4 @@ def setup_parser():

args = parser.parse_args()
main(args)

52 changes: 33 additions & 19 deletions dataset/config_utils.py
@@ -14,7 +14,7 @@

import toml
import voluptuous
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema, Optional as VOptional

from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset

@@ -34,6 +34,7 @@ class BaseDatasetParams:
num_repeats: int = 1
cache_directory: Optional[str] = None
debug_dataset: bool = False
is_val: bool = False


@dataclass
@@ -65,7 +66,8 @@ class DatasetGroupBlueprint:

@dataclass
class Blueprint:
dataset_group: DatasetGroupBlueprint
train_dataset_group: DatasetGroupBlueprint
val_dataset_group: DatasetGroupBlueprint


class ConfigSanitizer:
@@ -135,12 +137,11 @@ def validate_flex_dataset(dataset_config: dict):
self.general_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
)
self.user_config_validator = Schema(
{
"general": self.general_schema,
"datasets": [self.dataset_schema],
}
)
self.user_config_validator = Schema({
"general": self.general_schema,
"datasets": [self.dataset_schema],
"val_datasets": VOptional([self.dataset_schema]),
})
self.argparse_schema = self.__merge_dict(
self.ARGPARSE_SPECIFIC_SCHEMA,
)
@@ -187,26 +188,39 @@ def __init__(self, sanitizer: ConfigSanitizer):
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
general_config = sanitized_user_config.get("general", {})

dataset_blueprints = []
for dataset_config in sanitized_user_config.get("datasets", []):
# Process training datasets: is_val remains False (default)
train_dataset_configs = sanitized_user_config.get("datasets", [])
train_blueprints = []
for dataset_config in train_dataset_configs:
is_image_dataset = "target_frames" not in dataset_config
if is_image_dataset:
dataset_params_klass = ImageDatasetParams
else:
dataset_params_klass = VideoDatasetParams
dataset_params_klass = ImageDatasetParams if is_image_dataset else VideoDatasetParams
params = self.generate_params_by_fallbacks(
dataset_params_klass,
[dataset_config, general_config, argparse_config, runtime_params]
)
# is_val defaults to False in the dataclass so nothing special is needed
train_blueprints.append(DatasetBlueprint(is_image_dataset, params))

# Process validation datasets: mark them as validation.
val_dataset_configs = sanitized_user_config.get("val_datasets", [])
val_blueprints = []
for dataset_config in val_dataset_configs:
is_image_dataset = "target_frames" not in dataset_config
dataset_params_klass = ImageDatasetParams if is_image_dataset else VideoDatasetParams
params = self.generate_params_by_fallbacks(
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
dataset_params_klass,
[dataset_config, general_config, argparse_config, runtime_params]
)
dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
params.is_val = True # mark as validation
val_blueprints.append(DatasetBlueprint(is_image_dataset, params))

dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
train_dataset_group_blueprint = DatasetGroupBlueprint(train_blueprints)
val_dataset_group_blueprint = DatasetGroupBlueprint(val_blueprints)

return Blueprint(dataset_group_blueprint)
return Blueprint(train_dataset_group=train_dataset_group_blueprint, val_dataset_group=val_dataset_group_blueprint)

@staticmethod
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
85 changes: 84 additions & 1 deletion dataset/dataset_config.md
@@ -384,4 +384,87 @@ The metadata with .json file will be supported in the near future.



-->
-->

────────────────────────────────────────────────────────────────────────
UPDATED DOCUMENTATION: INCLUDING A VALIDATION DATASET
────────────────────────────────────────────────────────────────────────

1. Overview of “val_datasets”
Just like you can have multiple datasets for training under the "[[datasets]]" key, you can also specify validation datasets under "[[val_datasets]]". The syntax and options for each validation dataset are exactly the same as for training. The script will:
• Load and cache your validation datasets the same way it does for training (see the sketch after this list).
• Periodically compute a validation loss across these datasets (for example, once per epoch).
• Log the validation loss so you can monitor for over- or under-fitting.
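
For reference, here is a minimal sketch of how the caching scripts in this change build both groups from the blueprint and process them together. It mirrors the cache_latents.py diff above; "args", "config_utils", and "blueprint_generator" refer to objects already set up in that script and are not redefined here.

--------------------------------------------------------------------------------
# Minimal sketch (mirrors cache_latents.py): cache latents for BOTH the
# training and the validation datasets defined in the TOML config.
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args)

# The blueprint now carries two dataset groups instead of one.
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.train_dataset_group, training=True)
val_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.val_dataset_group, training=True)

# Combine them so the cache covers training and validation data alike.
datasets = train_dataset_group.datasets + val_dataset_group.datasets
--------------------------------------------------------------------------------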

2. Example TOML Configuration with Validation Datasets

Below is a minimal example that shows both training and validation datasets. Notice that “[[datasets]]” is used for training datasets, while “[[val_datasets]]” is reserved for any validation sets you want to include.

--------------------------------------------------------------------------------
[general]
caption_extension = ".txt"
batch_size = 1
enable_bucket = true
bucket_no_upscale = false

# PRIMARY TRAIN DATASETS
[[datasets]]
resolution = [640, 480]
video_directory = "path/to/training/video"
cache_directory = "path/to/cache/training/video"
frame_extraction = "head"
target_frames = [48]

[[datasets]]
resolution = [640, 480]
image_directory = "path/to/training/image"
cache_directory = "path/to/cache/training/image"

# ... you can add more [[datasets]] blocks if you have more training subsets ...

# VALIDATION DATASETS
[[val_datasets]]
resolution = [640, 480]
video_directory = "path/to/validation/video"
cache_directory = "path/to/cache/validation/video"
frame_extraction = "head"
target_frames = [48]

[[val_datasets]]
resolution = [640, 480]
image_directory = "path/to/validation/image"
cache_directory = "path/to/cache/validation/image"

# ... you can add more [[val_datasets]] blocks if you have more validation subsets ...
--------------------------------------------------------------------------------

Notes on usage:
• The script will treat the “[[datasets]]” entries as training data and “[[val_datasets]]” as validation data. Both sets can be a mix of images or videos.
• Each dataset or validation dataset must have a unique “cache_directory” to avoid overwriting latents or text-encoder caches.
• All of the same parameters (resolution, caption_extension, num_repeats, batch_size, etc.) work in exactly the same way under “val_datasets” as they do under “datasets.”

3. Running the Script with Validation
Once you have listed your training and validation datasets, you can run the training script as normal. For example:

--------------------------------------------------------------------------------
accelerate launch hv_train_network.py \
--dit path/to/DiT/model \
--dataset_config path/to/config_with_val.toml \
--max_train_epochs 20 \
... other arguments ...
--------------------------------------------------------------------------------

During training, the script will:
• Load and batch the training datasets.
• Perform training epochs, computing training loss.
• After each epoch (or at a specified interval), compute the validation loss using all “val_datasets” (see the sketch after this list).
• Log the validation performance (for example, “val_loss=...”) to TensorBoard and/or WandB if logging is enabled.
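
The exact hooks depend on the trainer, but the validation pass can be sketched roughly as follows. This is an illustrative outline only; "val_dataloader", "compute_loss", "transformer", and "global_step" are placeholder names, not the exact identifiers used in hv_train_network.py.

--------------------------------------------------------------------------------
import torch

def run_validation(accelerator, transformer, val_dataloader, compute_loss, global_step):
    # Illustrative sketch: evaluate without gradients and average the loss
    # over every batch drawn from the merged validation datasets.
    transformer.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in val_dataloader:
            loss = compute_loss(transformer, batch)  # same loss as in training
            total_loss += loss.detach().item()
            num_batches += 1
    transformer.train()

    val_loss = total_loss / max(num_batches, 1)
    # Reported next to the training metrics (TensorBoard / WandB via accelerate).
    accelerator.log({"val_loss": val_loss}, step=global_step)
    return val_loss
--------------------------------------------------------------------------------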

────────────────────────────────────────────────────────────────────────
ADDITIONAL TIPS
────────────────────────────────────────────────────────────────────────
• If you do not wish to do any validation, simply omit the “[[val_datasets]]” sections.
• You can add multiple validation blocks, each pointing to different image or video folders or JSONL metadata.
• The script merges all validation datasets into a single DataLoader when it computes validation loss (a minimal sketch of this merge follows this list).
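
As a rough illustration of that merge, assuming each validation dataset behaves like a standard torch.utils.data.Dataset, the combined loader could be built as below; the function name and its arguments are hypothetical.

--------------------------------------------------------------------------------
from torch.utils.data import ConcatDataset, DataLoader

def build_val_dataloader(val_datasets, batch_size, collate_fn=None, num_workers=2):
    # Illustrative sketch: concatenate every [[val_datasets]] entry into one
    # dataset and wrap it in a single DataLoader for the validation pass.
    merged = ConcatDataset(val_datasets)
    return DataLoader(
        merged,
        batch_size=batch_size,
        shuffle=False,  # validation order does not need shuffling
        collate_fn=collate_fn,
        num_workers=num_workers,
    )
--------------------------------------------------------------------------------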

With these changes to your config file, you can systematically evaluate your model’s performance after each epoch (or some other schedule) by leveraging the “val_datasets” blocks.
8 changes: 6 additions & 2 deletions dataset/image_video_dataset.py
@@ -748,6 +748,7 @@ def __init__(
bucket_no_upscale: bool = False,
cache_directory: Optional[str] = None,
debug_dataset: bool = False,
is_val: bool = False
):
self.resolution = resolution
self.caption_extension = caption_extension
@@ -757,6 +758,7 @@ def __init__(
self.bucket_no_upscale = bucket_no_upscale
self.cache_directory = cache_directory
self.debug_dataset = debug_dataset
self.is_val = is_val
self.seed = None
self.current_epoch = 0

@@ -910,9 +912,10 @@ def __init__(
image_jsonl_file: Optional[str] = None,
cache_directory: Optional[str] = None,
debug_dataset: bool = False,
is_val: bool = False
):
super(ImageDataset, self).__init__(
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset, is_val
)
self.image_directory = image_directory
self.image_jsonl_file = image_jsonl_file
@@ -1086,9 +1089,10 @@ def __init__(
video_jsonl_file: Optional[str] = None,
cache_directory: Optional[str] = None,
debug_dataset: bool = False,
is_val: bool = False
):
super(VideoDataset, self).__init__(
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset, is_val
)
self.video_directory = video_directory
self.video_jsonl_file = video_jsonl_file
14 changes: 12 additions & 2 deletions hunyuan_model/token_refiner.py
@@ -233,8 +233,18 @@ def forward(
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
Contributor Author comment: I was getting a crash on my 4090 (about type casting, I think); this change resolved it.

context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
if x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# do the operation in a safer fallback dtype (float32)
safe_x = x.float() # cast from float8 → float32
safe_mask = mask.float().unsqueeze(-1)
numerator = (safe_x * safe_mask).sum(dim=1)
denominator = safe_mask.sum(dim=1).clamp_min(1e-8) # avoid div-by-zero
out = numerator / denominator
context_aware_representations = out.to(x.dtype) # cast back to float8
else:
# the old logic for other dtypes
mask_float = mask.float().unsqueeze(-1)
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
