Add Validation Splits and Logging, Refactor Dataset Blueprints, and Improve Float8 Support #63
Conversation
Force-pushed from 27a633a to 3bc1a71.
Thank you! I think this is a great starting point. As for the dataset settings, just an idea: how about adding the …
```python
if mask is None:
    context_aware_representations = x.mean(dim=1)
else:
    mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
```
I was getting a crash on my 4090 about type casting, I think; this resolved it.
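For reference, here is a minimal sketch of how the full masked-mean computation with the explicit float cast might look. Everything outside the four diff lines above (the shapes, the sum/normalize step, the cast back to the input dtype) is an assumption, not the PR's exact code:

```python
import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
    # x: [b, s1, d]; mask: [b, s1], possibly bool/int, where 1 = keep token
    if mask is None:
        return x.mean(dim=1)
    # Casting the mask to float before multiplying avoids dtype-promotion
    # errors when the mask arrives as bool/int (the crash reported above).
    mask_float = mask.float().unsqueeze(-1)      # [b, s1, 1]
    summed = (x.float() * mask_float).sum(dim=1)  # [b, d]
    count = mask_float.sum(dim=1).clamp(min=1.0)  # [b, 1], avoid div-by-zero
    return (summed / count).to(x.dtype)
```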
I missed this earlier. Will take a look at what it would involve, but I don't see why that wouldn't work.
I've implemented this change, so there's no longer a need to modify the caching scripts. Good callout!
After re-testing, it looks like the cache_latents.py script is failing. It might need some changes to accommodate the …
Ok, just needed a small change to those caching files.
… new is_val dataset
I am trying out the PR and it looks like a great instrument to add to the toolkit. However, two changes, if practical, would enhance its effectiveness:
I got reports that this isn't working when you don't provide a validation dataset, which it was previously; validation should of course be optional. I'll see why that broke with the latest commit. It also seems I need to add support for JSON configs in addition to TOML; I missed that there are a variety of ways to specify datasets.
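On the TOML/JSON point, a minimal sketch of the idea (the actual loader in config_utils.py may look different) is to dispatch on the file extension:

```python
import json
from pathlib import Path

try:
    import tomllib  # Python 3.11+
except ModuleNotFoundError:
    import tomli as tomllib  # backport for older Pythons

def load_dataset_config(path: str) -> dict:
    """Load a dataset blueprint from either a TOML or a JSON file.

    This is a sketch of the approach, not the repo's actual loader.
    """
    p = Path(path)
    if p.suffix == ".toml":
        with p.open("rb") as f:  # tomllib requires a binary file
            return tomllib.load(f)
    if p.suffix == ".json":
        with p.open("r", encoding="utf-8") as f:
            return json.load(f)
    raise ValueError(f"Unsupported config format: {p.suffix}")
```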
Any chance of revisiting this implementation? Validation seems to be an important element for verifying training, and it appears to work well in conjunction with the TensorBoard loss chart.
• Introduces separate train/val dataset groups in the blueprint config and uses them for both cache_latents.py and cache_text_encoder_outputs.py (see the blueprint sketch after this list).
• Extends hv_train_network.py with a validate() function that runs on the val_dataset_group each epoch, computes MSE loss, and logs val_loss via Accelerate (e.g., accelerator.log(...)); a sketch follows the list.
• Refactors config_utils.py to return separate train_dataset_group and val_dataset_group, combining them when needed (all_datasets).
• Adds optional float8 fallback handling in token_refiner.py, safely casting float8 → float for calculations and then back to float8 (see the fallback sketch below).
• Adjusts cache skipping/keeping logic to handle old cache files, and changes the code to enumerate all_datasets instead of only the train datasets.
• Overall, this PR makes it possible to run a distinct validation pass each epoch, log validation performance, and unify caching for both train and val.
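To make the train/val split concrete, here is a hedged sketch of a blueprint and how it could be partitioned. The is_val flag is inferred from the commit message above, and the exact field names are assumptions, not the PR's actual schema:

```python
import tomllib  # Python 3.11+; older versions can use the tomli backport

BLUEPRINT = """
[[datasets]]
image_directory = "/data/train"

[[datasets]]
image_directory = "/data/val"
is_val = true  # assumed flag marking a validation-only dataset
"""

config = tomllib.loads(BLUEPRINT)
train_datasets = [d for d in config["datasets"] if not d.get("is_val", False)]
val_datasets = [d for d in config["datasets"] if d.get("is_val", False)]
all_datasets = train_datasets + val_datasets  # the caching scripts enumerate both
```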
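The validate() pass itself, per the description, computes MSE on the val group and logs val_loss via Accelerate. A minimal sketch, assuming a diffusion-style objective where the model predicts a target from noisy latents; names like val_dataloader and the batch layout are placeholders, not the PR's code:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def validate(accelerator, model, val_dataloader, global_step):
    model.eval()
    losses = []
    for batch in val_dataloader:
        noisy_latents, timesteps, target = batch  # placeholder batch layout
        pred = model(noisy_latents, timesteps)
        losses.append(F.mse_loss(pred.float(), target.float()))
    model.train()
    val_loss = torch.stack(losses).mean().item()
    accelerator.log({"val_loss": val_loss}, step=global_step)
    return val_loss
```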
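And for the float8 fallback: many reductions are not implemented for float8 dtypes, so the pattern described above upcasts for the computation and casts back afterward. A sketch of that pattern under those assumptions (the actual token_refiner.py code may differ):

```python
import torch

# float8 dtypes only exist on recent PyTorch builds, so probe for them.
FLOAT8_DTYPES = {
    getattr(torch, name)
    for name in ("float8_e4m3fn", "float8_e5m2")
    if hasattr(torch, name)
}

def mean_with_float8_fallback(x: torch.Tensor, dim: int) -> torch.Tensor:
    # float8 tensors don't support mean(); compute in float32, then cast
    # back so downstream layers still see the original dtype.
    if x.dtype in FLOAT8_DTYPES:
        return x.float().mean(dim=dim).to(x.dtype)
    return x.mean(dim=dim)
```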
Training run 1

Training run 2
