
Conversation

@NSFW-API
Contributor

@NSFW-API NSFW-API commented Jan 26, 2025

• Introduces separate train/val dataset groups in the blueprint config and uses them for both cache_latents.py and cache_text_encoder_outputs.py.
• Extends hv_train_network.py with a validate() function that runs on the val_dataset_group each epoch, computes MSE loss, and logs val_loss via Accelerate (e.g., accelerator.log(...)).
• Refactors config_utils.py to return separate train_dataset_group and val_dataset_group, combining them when needed (all_datasets).
• Adds optional float8 fallback handling in token_refiner.py, safely casting float8 → float for calculations and then back to float8.
• Adjusts cache skipping/keeping logic to handle old cache files, and changes the code to enumerate all_datasets instead of only train datasets.
• Overall, this PR makes it possible to run a distinct validation pass each epoch, log validation performance, and unify caching for both train and val.
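The per-epoch validation pass described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the names `val_dataloader` and `log_fn` are assumptions, and the real `validate()` in `hv_train_network.py` works on the `val_dataset_group` with cached latents/text-encoder outputs and logs through `accelerator.log(...)`.

```python
import torch

def validate(model, val_dataloader, device, epoch, log_fn=print):
    """Run one validation pass and return the mean MSE loss.

    Sketch only: the PR's validate() operates on val_dataset_group
    and logs val_loss via accelerator.log(...) instead of log_fn.
    """
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():  # no gradients needed for validation
        for inputs, targets in val_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            preds = model(inputs)
            total_loss += torch.nn.functional.mse_loss(preds, targets).item()
            num_batches += 1
    model.train()  # restore training mode for the next epoch
    val_loss = total_loss / max(num_batches, 1)
    log_fn({"val_loss": val_loss, "epoch": epoch})
    return val_loss
```

The key points the PR relies on are visible here: `eval()`/`no_grad()` so validation never updates weights, and a single scalar `val_loss` per epoch that the tracker (TensorBoard via Accelerate) can chart next to the training loss.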

Training run 1
Screenshot 2025-01-26 093517

Training run 2
Screenshot 2025-01-27 070937

@NSFW-API NSFW-API force-pushed the validation-dataset-stable-loss branch from 27a633a to 3bc1a71 Compare January 26, 2025 22:54
@kohya-ss
Owner

kohya-ss commented Jan 26, 2025

Thank you! I think this is a great starting point.

As for the dataset settings, just an idea, how about adding the is_val (or is_validation) attribute to the dataset? This might help minimize modifications to cache_latents.py etc.
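With that idea, a dataset entry in the TOML config might be tagged like this (hypothetical sketch of the suggestion only; the attribute name and surrounding keys are illustrative, not a final schema):

```toml
# Hypothetical sketch of the suggested is_val attribute
[[datasets]]
image_directory = "/path/to/train_images"
# no is_val key: treated as a training dataset

[[datasets]]
image_directory = "/path/to/val_images"
is_val = true  # mark this dataset as validation-only
```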

@NSFW-API NSFW-API marked this pull request as ready for review January 28, 2025 22:44
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
Contributor Author

I was getting a crash on my 4090 about type casting, I think; this resolved it.
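The float8 fallback mentioned in the PR summary follows a common pattern: upcast to float32 for operations that lack float8 kernels (such as the masked mean in the quoted snippet), then cast the result back. A rough sketch of that pattern, assuming a `masked_mean` helper for illustration (not the exact `token_refiner.py` change):

```python
import torch

# dtypes treated as float8, guarded so this also runs on older torch builds
FLOAT8_DTYPES = tuple(
    dt
    for dt in (getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e5m2", None))
    if dt is not None
)

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over the sequence dim, upcasting float8 inputs to float32 first."""
    orig_dtype = x.dtype
    if orig_dtype in FLOAT8_DTYPES:
        x = x.float()  # float8 -> float32 for the computation
    mask_float = mask.float().unsqueeze(-1)  # [b, s, 1]
    out = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
    if orig_dtype in FLOAT8_DTYPES:
        out = out.to(orig_dtype)  # cast back to float8 afterwards
    return out
```

Non-float8 inputs pass through untouched, so the fallback costs nothing on configurations that never hit the crash.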

@NSFW-API
Contributor Author

> Thank you! I think this is a great starting point.
>
> As for the dataset settings, just an idea, how about adding the is_val (or is_validation) attribute to the dataset? This might help minimize modifications to cache_latents.py etc.

I missed this earlier. Will take a look at what it would involve but I don't see why that wouldn't work.

@NSFW-API
Contributor Author

NSFW-API commented Feb 4, 2025

> Thank you! I think this is a great starting point.
>
> As for the dataset settings, just an idea, how about adding the is_val (or is_validation) attribute to the dataset? This might help minimize modifications to cache_latents.py etc.

I've implemented this change, no longer needing to modify the caching scripts. Good call out!

@NSFW-API NSFW-API marked this pull request as draft February 5, 2025 02:48
@NSFW-API
Contributor Author

NSFW-API commented Feb 5, 2025

After re-testing, looks like the cache_latents.py script is failing. It might need some changes to accommodate the is_val change after all. Will resubmit once I make sure all parts are working as expected.

@NSFW-API NSFW-API marked this pull request as ready for review February 5, 2025 04:17
@NSFW-API
Contributor Author

NSFW-API commented Feb 5, 2025

Ok, just needed a small change to those caching files.

@Enyakk

Enyakk commented Feb 9, 2025

I am trying out the PR and it looks like a great instrument to add to the toolkit. However, two changes, if practical, would enhance its effectiveness:

  • Large datasets might only see 1-2 epochs over the entire training run, so a steps_validate option (validating every N steps) would be appreciated.
  • Multiple named validation datasets in TensorBoard: training may have multiple goals. Being able to measure each goal against a defined validation dataset (i.e. character A, concept B, regularization C) and display the current validation state would enhance productivity even further.
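For the second request, one way such a config might look if it were implemented (purely hypothetical; neither these keys nor the feature exist in this PR):

```toml
# Hypothetical: named validation datasets, each charted as its own curve
[[datasets]]
image_directory = "/path/to/val_character_a"
is_val = true
val_name = "character_a"  # would log as val_loss/character_a

[[datasets]]
image_directory = "/path/to/val_concept_b"
is_val = true
val_name = "concept_b"
```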

@NSFW-API NSFW-API marked this pull request as draft February 9, 2025 19:30
@NSFW-API
Contributor Author

NSFW-API commented Feb 9, 2025

I got reports that this isn't working when you don't provide a validation dataset, which it did previously; validation should, of course, be optional. I'll look into why that broke with the latest commit. It also seems I need to add support for JSON configs in addition to TOML; I missed that there are a variety of ways to specify datasets.

@Sarania Sarania mentioned this pull request Apr 9, 2025
@niceguy4

Any chance of revisiting this implementation? Validation seems to be an important element for verifying training, and it appears to work well in conjunction with the TensorBoard loss chart.

4 participants