Add wavelet loss for networks #2037
base: sd3
Conversation
train_network.py (outdated)

    wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
    # Weight the losses as needed
    loss = loss + args.wavelet_loss_alpha * wav_loss
Maybe do loss = (1.0 - args.wavelet_loss_alpha) * loss + args.wavelet_loss_alpha * wav_loss so it's a proper interpolation?
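For illustration, the suggested blending can be sketched in plain Python (`combine_losses` is a hypothetical helper name, not part of the PR):

```python
def combine_losses(loss, wav_loss, alpha):
    # Convex combination: alpha = 0 keeps only the base loss,
    # alpha = 1 keeps only the wavelet loss; works on floats or tensors.
    return (1.0 - alpha) * loss + alpha * wav_loss
```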
You should consider including a weighting/masking scheme for the different levels; I am getting amazing results from it. I've been playing with a prototype of this myself, see #294 (reply in thread). For example, masking the lowpass elements makes it easier to learn subjects and objects without transferring the overall image aesthetic bias. Here is my hacky training code as an example:
Maybe we could have a parameter that takes an array of loss weights, one for each level of detail? For example, a 1024px image can be decomposed up to 8 levels. My testing gives interesting results when masking or weighting certain levels differently.
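A minimal sketch of the per-level weighting/masking idea (the function name and weight layout are hypothetical, not the PR's actual API):

```python
def weighted_wavelet_loss(level_losses, level_weights, ll_loss, ll_weight=0.0):
    # level_losses: one scalar loss per detail level (fine -> coarse);
    # level_weights: matching per-level weights (an array of loss weights);
    # ll_weight = 0.0 masks the lowpass (LL) band entirely.
    assert len(level_losses) == len(level_weights)
    detail = sum(w * l for w, l in zip(level_weights, level_losses))
    return detail + ll_weight * ll_loss
```

With `ll_weight=0.0` the overall low-frequency "aesthetic" content contributes nothing to the loss, which matches the masking experiment described above.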
@recris I went through and added inputs for band weighting, allowing a good amount of customization. It was in there previously for SWT, but now it is applied to both. The default weights are low so as to let you customize how much impact each band has. The alpha now acts more like a multiplier; ideally the weighting is controlled via the bands. I also added the wavelet loss to the logging, which helps when adjusting the band weights to see the wavelet impact.
I am still working through some trainings of this to get more examples and comparisons.
Getting this error when trying it: Traceback (most recent call last):
@EClipXAi Apologies, some newer functionality leaked in there. I have fixed this. Give it another shot and let me know if you have any issues.
- Add full conditional_loss functionality to wavelet loss
- Transforms are separate and abstracted
- Loss no longer includes LL except at the lowest level
- ll_level_threshold allows you to control the level at which LL is used in the loss
- Band weights can now be passed in
- Rectified flow calculations can be bypassed for experimentation
- Fixed alpha to 1.0, with the new weighted bands producing lower loss
Force-pushed from cea1930 to 3b949b9
Added QuaternionWaveletTransform, derived from https://arxiv.org/abs/2505.00334.
It should work similarly to DWT but with more components (4 vs. 1 for DWT), using Hilbert filters. It should be considered experimental; I'm not yet sure of recommended values.
Additionally, reworked SWT to work better; SWT should now be more performant, using 1D convolutions.
Added tests for the wavelet transforms and the wavelet loss to make sure they work as expected.
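Not the PR's implementation, but the "1D convolutions" idea can be illustrated with a one-level separable Haar DWT: two 1D filtering passes (along rows, then along columns) are enough to produce the four LL/LH/HL/HH bands of a 2D transform.

```python
import numpy as np

def haar_dwt_1d(x):
    # Single-level Haar analysis along the last axis (even length assumed).
    lo = (x[..., 0::2] + x[..., 1::2]) / np.sqrt(2.0)  # lowpass (average)
    hi = (x[..., 0::2] - x[..., 1::2]) / np.sqrt(2.0)  # highpass (difference)
    return lo, hi

def haar_dwt_2d(img):
    # Separable 2D transform: filter along width, then along height.
    lo, hi = haar_dwt_1d(img)
    ll, lh = haar_dwt_1d(lo.swapaxes(-1, -2))
    hl, hh = haar_dwt_1d(hi.swapaxes(-1, -2))
    return (ll.swapaxes(-1, -2), lh.swapaxes(-1, -2),
            hl.swapaxes(-1, -2), hh.swapaxes(-1, -2))
```

On a constant image the three detail bands come out exactly zero, which is a quick sanity check that only edges and texture land in the high-frequency bands.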
Sharing a good resource for comparing wavelets: https://www.mathworks.com/help/wavelet/gs/introduction-to-the-wavelet-families.html
I've installed the wavelet branch and added the wavelet loss configuration from above to my config.yaml. But when I start training, I don't see anything about wavelets in the command prompt. How do I know it works? Or what am I doing wrong?
@WinodePino During the startup process there will be a section that looks like this: In the config, if you do: or in the CLI:
Thanks! I found the issue, but now I run into this:
If I start a training with the normal SD3 branch, this error doesn't show up.
Do you have the full error? Also, which PyTorch version are you using?
Thanks for your help! At the moment version 2.7.1, but I also tried 2.4 and 2.5, I think.
It's saying: Otherwise it's a bigger issue.
Didn't work. Because my normal sd-scripts install did work, I tried manually replacing all the necessary changed files, and now it starts training and shows the section you mentioned during startup. I was expecting training to be a lot slower, though, but it seems to be similar, at 4 s/it on a 4090. So I am not sure if it is working as expected.
Added a new loss calculation; it helps a bit with semantic and structural alignment of the wavelet coefficients. Still not getting ideal results (correlation metrics are too flat and learning is too flat), but it is working better now. See Complex Wavelet Mutual Information Loss: A Multi-Scale Loss Function for Semantic Segmentation. I'm not using the paper above directly, but it added some additional context.

    band_loss = self.loss_fn(pred, target)
    # 2. Local structure loss
    pred_grad_x = torch.diff(pred, dim=-1)
    pred_grad_y = torch.diff(pred, dim=-2)
    target_grad_x = torch.diff(target, dim=-1)
    target_grad_y = torch.diff(target, dim=-2)
    gradient_loss = F.mse_loss(pred_grad_x, target_grad_x) + \
                    F.mse_loss(pred_grad_y, target_grad_y)
    # 3. Global correlation per channel
    B, C = pred.shape[:2]
    pred_flat = pred.view(B, C, -1)
    target_flat = target.view(B, C, -1)
    cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=2)
    correlation_loss = (1 - cos_sim).mean()
    weight = base_weight * band_level_weights.get(weight_key, band_weights[band])
    pattern_level_losses += weight.view(-1, 1, 1, 1) * (band_loss +
                                                        0.05 * gradient_loss +
                                                        0.1 * correlation_loss)  # mean stack dim

Will make the weighted loss configurable in the future.
To make this easier to develop, I made a separate library for wavelet-loss; you'll need to install it.
I left the old code in, but new wavelet code will be done separately.
Should wavelet_loss_type replace loss_type, and should I remove the latter or not? My results so far let me generate at higher resolution and it looks nice, but I expected it to be good in the details (I'm creating automotive LoRAs); somehow I got much better details with normal LoRA training on a dedistilled model. Am I missing something? I am training with Adafactor, 128 rank/alpha.
Found this PR finally. I hope you don't mind if I share something relevant to the topic. Someone on a forum made this wavelet implementation based on (according to him) many tests of different configurations. For those who want to manually integrate this into sd-scripts and aren't quite sure how, the monkey patch is:
In train_network:
In train_util:
a) Define a new loss type. Define the arg:
b) After the imports, insert:
c) Add to the imports:
I've been experimenting with different optimizers for a week now, and the quality achieved with this implementation is simply excellent. Far better than l1/l2/huber. It works fine with --debiased_estimation_loss, by the way.
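As a hedged sketch, step (a) above amounts to adding a "dwt" branch to the loss-type dispatch. The names and the `wavelet_loss` stub below are illustrative, not the actual patch:

```python
def wavelet_loss(pred, target):
    # Stub: in the real patch this would compute the loss in the DWT domain.
    return abs(pred - target)

def conditional_loss(pred, target, loss_type="l2"):
    # Hypothetical dispatcher mirroring train_util's loss-type selection.
    if loss_type == "l2":
        return (pred - target) ** 2
    if loss_type == "dwt":
        return wavelet_loss(pred, target)
    raise ValueError(f"unknown loss_type: {loss_type}")
```

With a branch like this in place, `--loss_type dwt` on the command line routes every loss computation through the wavelet path.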
I am trying to test this, but somehow latent caching gets stuck at 0%; any idea why that could be? And should it work with --loss_type dwt alone, or should I also add something to the config.toml like wavelet_loss = true? Thanks!
I'm not using the cache at all; I'll try a test run with it later. Maybe another manual patch is needed for train_util. Did you get any specific errors when it got stuck, or was it just the freeze itself?
It works with --loss_type="dwt" alone, as long as you've completed all the steps above.
Thanks, it is working now, but it doesn't seem to learn anything.
If the iterations are proceeding at a normal speed, then the problem presumably lies in the feature update rate being too low.











Some papers about wavelets:
https://arxiv.org/abs/2306.00306
https://arxiv.org/abs/2402.19215
https://arxiv.org/abs/2211.16152
https://arxiv.org/abs/2407.12538
https://arxiv.org/abs/2404.11273
Video about wavelets:
https://www.youtube.com/watch?v=jnxqHcObNK4
Example from "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses Enables Better Control of Artifacts":
An example to showcase what this is supposed to help with.
Install
Usage
Activate Wavelet Loss:
Configure
Configure Wavelet loss:
Recommended starting point:
Huber or L1 is recommended to get the right loss for the signals. Level 2 picks up more detail, so it is best for capturing those fine details. You need wavelet_loss=true to enable the wavelet loss.
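As a hypothetical starting point, expressed as a Python dict (only wavelet_loss, wavelet_loss_alpha, and wavelet_loss_type appear in this thread; the levels key and the exact values are illustrative guesses, not documented defaults):

```python
# Illustrative config fragment; mirror these keys in your config.toml/yaml.
wavelet_config = {
    "wavelet_loss": True,         # required to enable the wavelet loss path
    "wavelet_loss_alpha": 1.0,    # overall multiplier (fixed at 1.0 per the changelog)
    "wavelet_loss_type": "huber", # huber or l1 recommended for the band signals
    "levels": 2,                  # level 2 picks up finer detail (assumed key name)
}
```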
CLI
Wavelet families:
We use a custom loss implementation, and since I only learned about wavelets recently, we may reconsider the approach and perhaps utilize other libraries as well.
I am using flow matching/rectified flow to predict denoised latents, to create a better target for the wavelets, so this might not work as well for some models like SD1.5/SDXL, but I am not sure.
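The "predict denoised latents" step can be sketched under the common rectified-flow convention x_t = (1 - t) * x0 + t * noise with velocity v = noise - x0 (an assumption on my part; the PR's exact parameterization may differ):

```python
def predicted_clean_latent(x_t, v_pred, t):
    # Since x_t = x0 + t * v under this convention,
    # the denoised estimate is x0_hat = x_t - t * v_pred.
    return x_t - t * v_pred
```

The wavelet loss is then computed between this x0_hat and the clean latent, rather than between raw noise predictions, which is why the frequency bands carry meaningful image structure.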
Related #2016