Skip to content

Conversation

@rockerBOO
Copy link
Contributor

@rockerBOO rockerBOO commented Apr 8, 2025

Some papers about wavelets:
https://arxiv.org/abs/2306.00306
https://arxiv.org/abs/2402.19215
https://arxiv.org/abs/2211.16152
https://arxiv.org/abs/2407.12538
https://arxiv.org/abs/2404.11273

Video about Wavelets
https://www.youtube.com/watch?v=jnxqHcObNK4

Example from Training Generative Image Super-Resolution Models by Wavelet-Domain Losses Enables Better Control of Artifacts

An example to showcase what this is suppose to help with.

WGSR WGSR
Screenshot 2025-04-07 at 20-08-04 2402 19215v1 pdf Screenshot 2025-04-07 at 20-27-18 2402 19215v1 pdf

Screenshot 2025-04-07 at 18-41-54 Wavelets a mathematical microscope - Invidious

Wavelet examples from: https://www.youtube.com/watch?v=jnxqHcObNK4

Install

pip install PyWavelets git+https://github.com/rockerBOO/wavelet-loss

Usage

Activate Wavelet Loss:

wavelet_loss = true

Configure

Configure Wavelet loss:

wavelet_loss = true
wavelet_loss_metrics = true

wavelet_loss_type = "huber" # l2, l1, smooth_l2, huber. l1 and huber would be recommended
wavelet_loss_transform = "swt" # dwt, swt. swt keeps the spatial details but dwt can be a little more efficient.
wavelet_loss_wavelet = "sym7" # over 100 wavelets, but try db4, sym7. AI toolkit uses haar. 
wavelet_loss_level = 2 # Loss level is how many levels we process. DWT maybe 1-3, SWT 1-2. SWT level 2 focuses on details more. 
wavelet_loss_alpha = 1.0 # How much impact the loss has on the latent loss value
wavelet_loss_energy_ratio = 0.0 # Experimental: Try to match the energy between latents, 0.0 is the default right now.
wavelet_loss_normalize_bands = false # Normalize the bands between 0 and 1. Defaults to true (it might be better as false)
wavelet_loss_band_level_weights = { "ll1" = 0.1, "lh1" = 0.01, "hl1" = 0.01, "hh1" = 0.05, "ll2" = 0.1, "lh2" = 0.01, "hl2" = 0.01, "hh2" = 0.05 } # Set the individual levels band weights. Starts at level 1. If a band is not set it defaults to band_weights below
wavelet_loss_band_weights = { "ll" = 0.1, "lh" = 0.5, "hl" = 0.5, "hh" = 0.3 } # Sets the defaults for the bands. 
wavelet_loss_ll_level_threshold = -1 # level to process the ll at. Low frequency will be similar to the original latent so only need the last levels for that detail.
wavelet_loss_rectified_flow = true # Experimental. Not recommended to change, but toggles rectified flow to get clean latents.

Recommended starting point:

wavelet_loss = true
wavelet_loss_type = "huber" # l2, l1, smooth_l2, huber. l1 and huber would be recommended
wavelet_loss_level = 2 # Loss level is how many levels we process. DWT maybe 1-3, SWT 1-2. SWT level 2 focuses on details more. 

Huber or l1 are recommended to get the right loss for the signals. Level 2 gets more detail so will be best to capture those fine details. Need wavelet_loss=true to enable wavelet_loss.

CLI

--wavelet_loss --wavelet_loss_type huber --wavelet_loss_level 2

Wavelet families:

    haar family: haar
    db family: db1, db2, db3, db4, db5, db6, db7, db8, db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20, db21, db22, db23, db24, db25, db26, db27, db28, db29, db30, db31, db32, db33, db34, db35, db36, db37, db38
    sym family: sym2, sym3, sym4, sym5, sym6, sym7, sym8, sym9, sym10, sym11, sym12, sym13, sym14, sym15, sym16, sym17, sym18, sym19, sym20
    coif family: coif1, coif2, coif3, coif4, coif5, coif6, coif7, coif8, coif9, coif10, coif11, coif12, coif13, coif14, coif15, coif16, coif17
    bior family: bior1.1, bior1.3, bior1.5, bior2.2, bior2.4, bior2.6, bior2.8, bior3.1, bior3.3, bior3.5, bior3.7, bior3.9, bior4.4, bior5.5, bior6.8
    rbio family: rbio1.1, rbio1.3, rbio1.5, rbio2.2, rbio2.4, rbio2.6, rbio2.8, rbio3.1, rbio3.3, rbio3.5, rbio3.7, rbio3.9, rbio4.4, rbio5.5, rbio6.8
    dmey family: dmey
    gaus family: gaus1, gaus2, gaus3, gaus4, gaus5, gaus6, gaus7, gaus8
    mexh family: mexh
    morl family: morl
    cgau family: cgau1, cgau2, cgau3, cgau4, cgau5, cgau6, cgau7, cgau8
    shan family: shan
    fbsp family: fbsp
    cmor family: cmor

We use a custom loss implementation and I just learned about Wavelets like yesterday so we may consider how the approach works and maybe utilize other libraries as well.

I am using flow matching/rectified flow attempts to predict denoised latents to create a better result for wavelets, so might not work as well for some models like SD1.5/SDXL but I am not sure.

Related #2016

@rockerBOO rockerBOO changed the title Add wavelet loss Add wavelet loss for networks Apr 8, 2025
train_network.py Outdated

wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
# Weight the losses as needed
loss = loss + args.wavelet_loss_alpha * wav_loss
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe do loss = (1.0 - args.wavelet_loss_alpha) * loss + args.wavelet_loss_alpha * wav_loss so its a proper interpolation?

@recris
Copy link

recris commented Apr 11, 2025

You should consider including a weighting/masking scheme for the different levels, I am getting amazing results from it.

I've been playing with a prototype of this myself, see #294 (reply in thread)

For example, masking the lowpass elements makes it easier to learn subjects and objects without transferring the overall image aesthetic bias.

Here is my hacky training code as an example:

use_wavelet_loss = True
        wavelet_loss_ratio = 0.98

        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)

        if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
            # not applying the mask by multiplying the loss, instead this scales the gradients directly
            # it enables using other loss functions below without caring about having to worry
            # about mask being incompatible
             loss, mask = apply_masked_loss(loss, batch)
            mask = torch.nn.functional.interpolate(mask, size=noise_pred.shape[2:], mode="area")
            noise_pred.register_hook(lambda grad: grad * mask.to(grad.dtype))

        loss = loss.mean([1, 2, 3])

        if use_wavelet_loss:
            # custom weighting scheme, this should be parameter
            num_levels = 3
            level_weights = [1.0] * num_levels
            lowpass_weight = 0.0
            level_weights[1] = 0.0
            level_weights[2] = 0.0

            # this requires pytorch-wavelets to be installed
            dwt = DWTForward(J=num_levels, mode='zero', wave='haar').to(device=accelerator.device, dtype=vae_dtype)

            model_pred_xl, model_pred_xh = dwt(noise_pred)
            target_xl, target_xh = dwt(target)

            # compute lowpass loss
            wt_loss = train_util.conditional_loss(model_pred_xl.float(), target_xl.float(), args.loss_type, "none", huber_c)
            wt_loss = wt_loss * lowpass_weight

            # compute loss for each band
            for lvl, (p, t) in enumerate(zip(reversed(model_pred_xh), reversed(target_xh))):
                l = train_util.conditional_loss(p.float(), t.float(), args.loss_type, "none", huber_c)
                l = l * level_weights[lvl]

                l_xlh, l_xhl, l_xhh = torch.unbind(l, dim=2)

                wt_loss = torch.cat((
                    torch.cat((wt_loss, l_xlh), dim=3),
                    torch.cat((l_xhl, l_xhh), dim=3)),
                    dim=2)

            wt_loss = wt_loss.mean([1, 2, 3])

            loss = wavelet_loss_ratio * wt_loss + (1 - wavelet_loss_ratio) * loss

Maybe we could have a parameter to pass an array of loss weights, one for each level of detail? For example a 1024px image can be decomposed up to 8 levels. My testing gives me interesting results when masking or weighting certain levels differently.

@rockerBOO
Copy link
Contributor Author

@recris I went through and added inputs for band weighting, allowing a good amount of customization. It was in there previously for SWT but now it is applied to both. The default weights are low so as to allow one to customize how much impact each band has. The alpha is now more like a multiplier instead but should control the weighting via the bands ideally.

I added wavelet loss to the logging, which can help when adjusting the band weights to see the wavelet impacts.
Screenshot 2025-04-12 at 02-18-58 women-flux-kohya-lora Workspace – Weights   Biases

Example Example
c-f1-2025-04-12_00059_ c-f1-2025-04-12_00058_
c-f1-2025-04-12_00028_ c-f1-2025-04-12_00027_
c-f1-2025-04-12_00008_ c-f1-2025-04-12_00024_

I am still working through some trainings of this to get some more examples and comparisons.

@EClipXAi
Copy link

Traceback (most recent call last):
File "/workspace/kohya_ss/sd-scripts/flux_train_network.py", line 559, in
trainer.train(args)
File "/workspace/kohya_ss/sd-scripts/train_network.py", line 1473, in train
loss, wav_loss = self.process_batch(
File "/workspace/kohya_ss/sd-scripts/train_network.py", line 496, in process_batch
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
File "/workspace/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/kohya_ss/sd-scripts/library/custom_train_functions.py", line 804, in forward
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack)
File "/workspace/kohya_ss/sd-scripts/train_network.py", line 488, in loss_fn
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
TypeError: get_huber_threshold_if_needed() takes 3 positional arguments but 4 were given
Traceback (most recent call last):
File "/workspace/kohya_ss/sd-scripts/flux_train_network.py", line 559, in
trainer.train(args)
File "/workspace/kohya_ss/sd-scripts/train_network.py", line 1473, in train
loss, wav_loss = self.process_batch(
File "/workspace/kohya_ss/sd-scripts/train_network.py", line 496, in process_batch
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
File "/workspace/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/kohya_ss/sd-scripts/library/custom_train_functions.py", line 804, in forward
band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack)
File "/workspace/kohya_ss/sd-scripts/train_network.py", line 488, in loss_fn
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
TypeError: get_huber_threshold_if_needed() takes 3 positional arguments but 4 were given

getting this error when trying it

@rockerBOO
Copy link
Contributor Author

@EClipXAi Apologies, some newer functionality leaked in there. I have fixed this. Give it another shot and let me know if you have any issues.

rockerBOO added 13 commits May 4, 2025 18:38
- add full conditional_loss functionality to wavelet loss
- Transforms are separate and abstracted
- Loss now doesn't include LL except the lowest level
  - ll_level_threshold allows you to control the level the ll is
    used in the loss
- band weights can now be passed in
- rectified flow calculations can be bypassed for experimentation
- Fixed alpha to 1.0 with new weighted bands producing lower loss
@rockerBOO rockerBOO force-pushed the network-wavelet-loss branch from cea1930 to 3b949b9 Compare May 4, 2025 22:58
@rockerBOO
Copy link
Contributor Author

rockerBOO commented May 5, 2025

Added QuaterionWaveletTransform derived from https://arxiv.org/abs/2505.00334.

wavelet_loss_transform = "qwt"

wavelet_loss = true
wavelet_loss_type = "huber"
wavelet_loss_transform = "qwt"
wavelet_loss_wavelet = "sym7"
wavelet_loss_level = 3
wavelet_loss_alpha = 1
wavelet_loss_band_weights = { "ll" = 0.25, "lh" = 1.0, "hl" = 1.0, "hh" = 1.0 }
wavelet_loss_quaternion_component_weights = { "r" = 0.25, "i" = 0.5, "j" = 0.5, "k" = 0.5 }
wavelet_loss_ll_level_threshold = -1

It should work similarly to DWT but with more components (4 vs 1 with DWT) using hilbert filters. It's probably experimental so true recommended values I'm not quite sure yet.

wavelet_loss_quaternion_component_weights keys might change to r, x, y, xy but will mention when/if it changes.

Additionally reworked SWT to work better, and SWT should be more performant now using 1d convolutions.

wavelet_loss_transform wasn't properly being used before so now swt and qwt should work where it was only using DWT before.

Added tests for the Wavelet tranforms and Wavelet loss to make sure it's working as expected.

67372a added a commit to 67372a/sd-scripts that referenced this pull request May 13, 2025
@67372a
Copy link

67372a commented May 25, 2025

Sharing a good resource for comparing wavelets:

https://www.mathworks.com/help/wavelet/gs/introduction-to-the-wavelet-families.html

@WinodePino
Copy link

I've installed the wavelet branch and added the configure wavelet loss from above to my config.yaml. But when I start the training I do not see anything back in the command prompt about wavelet. How do I know it works? Or what am I doing wrong?

@rockerBOO
Copy link
Contributor Author

rockerBOO commented Jul 7, 2025

@WinodePino During the startup process there will be a section that is like

Wavelet Loss:
	Level: 2
	Alpha: 1.0
	Transform: swt
	Wavelet: sym7
	LL level threshold: -1

In the config if you do

wavelet_loss_metrics = true

or in the CLI --wavelet_loss_metrics it will produce metrics that will be logged (like to tensorboard, wandb)

@WinodePino
Copy link

@WinodePino During the startup process there will be a section that is like

Wavelet Loss:
	Level: 2
	Alpha: 1.0
	Transform: swt
	Wavelet: sym7
	LL level threshold: -1

In the config if you do

wavelet_loss_metrics = true

or in the CLI --wavelet_loss_metrics it will produce metrics that will be logged (like to tensorboard, wandb)

Thanks! I found the issue, but now I run into this:

AttributeError: module 'torch' has no attribute 'float8_e4m3fnuz'

If I start a training with the normal SD3 branch this error won't show up.

@rockerBOO
Copy link
Contributor Author

Thanks! I found the issue, but now I run into this:

AttributeError: module 'torch' has no attribute 'float8_e4m3fnuz'

If I start a training with the normal SD3 branch this error won't show up.

Do you have all of the error? Also which pytorch version are using?

@WinodePino
Copy link

Thanks for you help! At the moment version 2.7.1, but I also tried 2.4 and 2.5 I think.

INFO Loaded Flux: <All keys matched successfully> flux_utils.py:137Traceback (most recent call last): File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\flux_train_network.py", line 559, in <module> trainer.train(args) File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\train_network.py", line 648, in train model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\flux_train_network.py", line 103, in load_target_model if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\venv\lib\site-packages\torch\__init__.py", line 1833, in __getattr__ raise AttributeError(f"module '{__name__}' has no attribute '{name}'") AttributeError: module 'torch' has no attribute 'float8_e4m3fnuz' Traceback (most recent call last): File "C:\Users\W_vel\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 196, in _run_module_as_main return _run_code(code, main_globals, None, File "C:\Users\W_vel\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 86, in _run_code exec(code, run_globals) File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\venv\Scripts\accelerate.exe\__main__.py", line 7, in <module> File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\venv\lib\site-packages\accelerate\commands\accelerate_cli.py", line 48, in main args.func(args) File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\venv\lib\site-packages\accelerate\commands\launch.py", line 1106, in launch_command simple_launcher(args) File "B:\02_AI\sd-scripts-network-wavelet-loss\sd-scripts-network-wavelet-loss\venv\lib\site-packages\accelerate\commands\launch.py", line 704, in simple_launcher raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) subprocess.CalledProcessError: Command '['B:\\02_AI\\sd-scripts-network-wavelet-loss\\sd-scripts-network-wavelet-loss\\venv\\Scripts\\python.exe', 'B:/02_AI/sd-scripts-network-wavelet-loss/sd-scripts-network-wavelet-loss/flux_train_network.py', '--config_file', 'B:/02_AI/sd-scripts-network-wavelet-loss/sd-scripts-network-wavelet-loss\\config/config_lora-20250705-130432.toml']' returned non-zero exit status 1.

@rockerBOO
Copy link
Contributor Author

it's saying torch.float8_e4m3fnuz isn't a attribute which is likely a pytorch version issue. With the venv enabled do

python -c "import torch; print(torch.__version__)"

Otherwise it's a bigger issue.

@WinodePino
Copy link

it's saying torch.float8_e4m3fnuz isn't a attribute which is likely a pytorch version issue. With the venv enabled do

python -c "import torch; print(torch.__version__)"

Otherwise it's a bigger issue.

Didn't work. Because my normal SD Scripts did work I tried to manually replace all necessary changed files and now it starts the training and it shows the section you mentioned during start up.

Although I was expecting training to be a lot slower, but it seem to be similar with 4s/it on a 4090. So I am not sure if it is working as expected.

@rockerBOO
Copy link
Contributor Author

rockerBOO commented Jul 15, 2025

Added new loss calculation, helps a bit with aligning semantically and structurally of the coefficients of the wavelets. Still not getting ideal results (correlation metrics are too flat and learning is too flat) but it is working better now.

Complex Wavelet Mutual Information Loss: A Multi-Scale Loss Function for Semantic Segmentation

Not using the paper above directly but it added some additional context

                    band_loss = self.loss_fn(pred, target)

                    # 2. Local structure loss
                    pred_grad_x = torch.diff(pred, dim=-1)
                    pred_grad_y = torch.diff(pred, dim=-2)  
                    target_grad_x = torch.diff(target, dim=-1)
                    target_grad_y = torch.diff(target, dim=-2)

                    gradient_loss = F.mse_loss(pred_grad_x, target_grad_x) + \
                                   F.mse_loss(pred_grad_y, target_grad_y)

                    # 3. Global correlation per channel
                    B, C = pred.shape[:2]
                    pred_flat = pred.view(B, C, -1)
                    target_flat = target.view(B, C, -1)

                    cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=2)
                    correlation_loss = (1 - cos_sim).mean()

                    weight = base_weight * band_level_weights.get(weight_key, band_weights[band])
                    pattern_level_losses += weight.view(-1, 1, 1, 1) * (band_loss + 
                        0.05 * gradient_loss + 
                        0.1 * correlation_loss)  # mean stack dim

Will make the weighted loss configurable in the future.

@rockerBOO
Copy link
Contributor Author

Can see our pred is really blank but our target is maybe more accurate. I think this is causing some sort of issue but haven't gotten far enough to figure out why it's like this.

compare pred_hh1

@rockerBOO
Copy link
Contributor Author

To make this be easier to develop, I made a separate library for wavelet-loss. You'll need to install it

pip install git+https://github.com/rockerBOO/wavelet-loss

I left the code in but new Wavelet code will be done separately.

@WinodePino
Copy link

Should wavelet_loss_type replace loss_type and should I remove it or not? My results so far are letting me generate on higher resolution and it looks nice, but I expected it to be good in the details (creating automotive lora's), but somehow I got way better details with normal LoRa training on a dedistilled model. Am I missing something? I am training on adafactor, 128 rank/alpha.

loss_type = "l2" wavelet_loss = true wavelet_loss_metrics = true wavelet_loss_type = "l1" # l2, l1, smooth_l2, huber. l1 and huber would be recommended wavelet_loss_transform = "swt" # dwt, swt. swt keeps the spatial details but dwt can be a little more efficient. wavelet_loss_wavelet = "sym7" # over 100 wavelets, but try db4, sym7. AI toolkit uses haar. wavelet_loss_level = 2 # Loss level is how many levels we process. DWT maybe 1-3, SWT 1-2. SWT level 2 focuses on details more. wavelet_loss_alpha = 1.0 # How much impact the loss has on the latent loss value wavelet_loss_band_level_weights = { "ll1" = 0.1, "lh1" = 0.01, "hl1" = 0.01, "hh1" = 0.05, "ll2" = 0.1, "lh2" = 0.01, "hl2" = 0.01, "hh2" = 0.05 } # Set the individual levels band weights. Starts at level 1 wavelet_loss_band_weights = { "ll" = 0.1, "lh" = 0.01, "hl" = 0.01, "hh" = 0.05 } # Sets the defaults for the bands. Currently need to set all the values or bad things might happen. wavelet_loss_ll_level_threshold = -1 # level to process the ll at. Low frequency will be similar to the original latent so only need the last levels for that detail. wavelet_loss_rectified_flow = true # Experimental. Not recommended to change, but toggles rectified flow to get clean latents.

@rockerBOO
Copy link
Contributor Author

Should wavelet_loss_type replace loss_type and should I remove it or not? My results so far are letting me generate on higher resolution and it looks nice, but I expected it to be good in the details (creating automotive lora's), but somehow I got way better details with normal LoRa training on a dedistilled model. Am I missing something? I am training on adafactor, 128 rank/alpha.

One of the issues with the wavelets on the latents is they are pretty noisy. I'm hoping the newer models use a better VAE that keeps the semantic detail better in the latents.

Image based wavelets
wavelet_transforms_dwt

For example with the SD 1.5 fine tuned VAE:
vae_latent_transforms_swt_haar_L3_b1730333

I make my separate repo for wavelet loss to be able to make these visualizations for confirmation. For Flux it is 16 channels so the resulting image is quite massive to show here, so trying to figure out the best way to approach debugging it.

Because the channels have such large differences in the wavelets they pick up in the latents it might be we look at individual channels for the loss or weight the channels differently.

I want to make a more in-depth writeup about these, as to why this PR is still in limbo. It should work better but I haven't had phenomenal results yet.

If you have some examples of your results, it might be helpful to confer with.

@deGENERATIVE-SQUAD
Copy link

deGENERATIVE-SQUAD commented Jul 25, 2025

Found this PR finally. I hope you don’t mind if I share something relevant to the topic.

A guy from an one forum made this wavelet implementation based on many tests (according to his speech) of different configurations.

@torch.compile(dynamic=True)
def compute_weights(timestep, width, shift):
    timestep = min(1000, 1000 - timestep + shift)
    x = torch.linspace(0, 1000, 1000).cuda()
    func = lambda x, mu, sigma, amplitude: amplitude * torch.exp(-(x - mu)**2 / (2 * sigma**2))
    y = func(x, mu=timestep, sigma=width, amplitude=1.0)
    weights = [y[i].item() for i in [0, 200, 400, 600, 800, 999]]
    return weights`
@torch.compile(dynamic=True)`
def compute_loss(model_pred, target, timesteps):
    loss_levels = []
    for b in range(model_pred.shape[0]):
        pred = model_pred[b, ...].unsqueeze(0)
        tgt = target[b, ...].unsqueeze(0)
        num_levels = math.ceil(math.log2(max(pred.shape[2], pred.shape[3])))
        dwt = DWTForward(J=num_levels, mode='zero', wave="haar").to(device=pred.device, dtype=torch.float)`

        model_pred_xl, model_pred_xh = dwt(pred.float())
        target_xl, target_xh = dwt(tgt.float())
        model_pred_l0 = model_pred_xl.unsqueeze(2)
        target_l0 = target_xl.unsqueeze(2)

        level_weights = compute_weights(timesteps[b], 350, 0)
        for p, t, w in zip(model_pred_xh + [model_pred_l0], target_xh + [target_l0], range(1, num_levels + 1)):
            l = F.mse_loss(p.float(), t.float(), reduction="none")
            loss_levels.append(level_weights[min(num_levels + 1 - w, len(level_weights) - 1)] * l.mean([2,3,4]) / float(2 ** w))

    loss = torch.stack(loss_levels, dim=2).mean([1, 2]).mean()
    return loss

For those who want to manually integrate this into sd-scripts and aren’t quite sure how, the monkey patch is:

In train_network:

huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c, timesteps)
        if weighting is not None:
            loss = loss * weighting
        if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
            loss = apply_masked_loss(loss, batch)
        loss = loss

In train_util:

a) Define a new loss type:

def conditional_loss(
    model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None, timesteps: Optional[torch.Tensor] = None
):
    """
    NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
    """
    if loss_type == "l2":
        loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
    elif loss_type == "dwt":
        if timesteps is None:
            raise ValueError("Wavelet loss requires `timesteps`")
        loss_levels = []
        for b in range(model_pred.shape[0]):
            pred = model_pred[b, ...].unsqueeze(0)
            tgt = target[b, ...].unsqueeze(0)
            num_levels = math.ceil(math.log2(max(pred.shape[2], pred.shape[3])))
            dwt = DWTForward(J=num_levels, mode='zero', wave="haar").to(device=pred.device, dtype=torch.float)

            model_pred_xl, model_pred_xh = dwt(pred.float())
            target_xl, target_xh = dwt(tgt.float())
            model_pred_l0 = model_pred_xl.unsqueeze(2)
            target_l0 = target_xl.unsqueeze(2)

            level_weights = compute_weights(timesteps[b], 350, 0)
            for p, t, w in zip(model_pred_xh + [model_pred_l0], target_xh + [target_l0], range(1, num_levels + 1)):
                l = F.mse_loss(p.float(), t.float(), reduction="none")
                weight_idx = min(num_levels + 1 - w, len(level_weights) - 1)
                loss_levels.append(level_weights[weight_idx] * l.mean([2, 3, 4]))

        loss = torch.stack(loss_levels, dim=0).mean()
    elif loss_type == "l1":
        loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
    elif loss_type == "huber":
        if huber_c is None:
            raise NotImplementedError("huber_c not implemented correctly")
        huber_c = huber_c.view(-1, 1, 1, 1)
        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
    elif loss_type == "smooth_l1":
        if huber_c is None:
            raise NotImplementedError("huber_c not implemented correctly")
        huber_c = huber_c.view(-1, 1, 1, 1)
        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
    else:
        raise NotImplementedError(f"Unsupported Loss Type: {loss_type}")
    return loss

Define arg:

parser.add_argument(
        "--loss_type",
        type=str,
        default="l2",
        choices=["l1", "l2", "huber", "smooth_l1", "dwt"],
        help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2",
    )

b) After the imports, insert:

def compute_weights(timestep, width, shift):
    timestep = min(1000, 1000 - timestep + shift)
    x = torch.linspace(0, 1000, 1000).cuda()
    func = lambda x, mu, sigma, amplitude: amplitude * torch.exp(-(x - mu)**2 / (2 * sigma**2))
    y = func(x, mu=timestep, sigma=width, amplitude=1.0)
    weights = [y[i].item() for i in [0, 200, 400, 600, 800, 999]]
    return weights

c) Add to the imports:

import torch.nn.functional as F  
from pytorch_wavelets import DWTForward

I’ve been experimenting with different optimizers for a week now, and the quality achieved with this implementation is simply excellent. Far better than l1/l2/huber. It works fine with --debiased_estimation_loss, by the way.

@WinodePino
Copy link

Found this PR finally. I hope you don’t mind if I share something relevant to the topic.

A guy from an one forum made this wavelet implementation based on many tests (according to his speech) of different configurations.

@torch.compile(dynamic=True)
def compute_weights(timestep, width, shift):
    timestep = min(1000, 1000 - timestep + shift)
    x = torch.linspace(0, 1000, 1000).cuda()
    func = lambda x, mu, sigma, amplitude: amplitude * torch.exp(-(x - mu)**2 / (2 * sigma**2))
    y = func(x, mu=timestep, sigma=width, amplitude=1.0)
    weights = [y[i].item() for i in [0, 200, 400, 600, 800, 999]]
    return weights`
@torch.compile(dynamic=True)`
def compute_loss(model_pred, target, timesteps):
    loss_levels = []
    for b in range(model_pred.shape[0]):
        pred = model_pred[b, ...].unsqueeze(0)
        tgt = target[b, ...].unsqueeze(0)
        num_levels = math.ceil(math.log2(max(pred.shape[2], pred.shape[3])))
        dwt = DWTForward(J=num_levels, mode='zero', wave="haar").to(device=pred.device, dtype=torch.float)`

        model_pred_xl, model_pred_xh = dwt(pred.float())
        target_xl, target_xh = dwt(tgt.float())
        model_pred_l0 = model_pred_xl.unsqueeze(2)
        target_l0 = target_xl.unsqueeze(2)

        level_weights = compute_weights(timesteps[b], 350, 0)
        for p, t, w in zip(model_pred_xh + [model_pred_l0], target_xh + [target_l0], range(1, num_levels + 1)):
            l = F.mse_loss(p.float(), t.float(), reduction="none")
            loss_levels.append(level_weights[min(num_levels + 1 - w, len(level_weights) - 1)] * l.mean([2,3,4]) / float(2 ** w))

    loss = torch.stack(loss_levels, dim=2).mean([1, 2]).mean()
    return loss

For those who want to manually integrate this into sd-scripts and aren’t quite sure how, the monkey patch is:

In train_network:

huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c, timesteps)
        if weighting is not None:
            loss = loss * weighting
        if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
            loss = apply_masked_loss(loss, batch)
        loss = loss

In train_util:

a) Define a new loss type:

def conditional_loss(
    model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None, timesteps: Optional[torch.Tensor] = None
):
    """
    NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
    """
    if loss_type == "l2":
        loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
    elif loss_type == "dwt":
        if timesteps is None:
            raise ValueError("Wavelet loss requires `timesteps`")
        loss_levels = []
        for b in range(model_pred.shape[0]):
            pred = model_pred[b, ...].unsqueeze(0)
            tgt = target[b, ...].unsqueeze(0)
            num_levels = math.ceil(math.log2(max(pred.shape[2], pred.shape[3])))
            dwt = DWTForward(J=num_levels, mode='zero', wave="haar").to(device=pred.device, dtype=torch.float)

            model_pred_xl, model_pred_xh = dwt(pred.float())
            target_xl, target_xh = dwt(tgt.float())
            model_pred_l0 = model_pred_xl.unsqueeze(2)
            target_l0 = target_xl.unsqueeze(2)

            level_weights = compute_weights(timesteps[b], 350, 0)
            for p, t, w in zip(model_pred_xh + [model_pred_l0], target_xh + [target_l0], range(1, num_levels + 1)):
                l = F.mse_loss(p.float(), t.float(), reduction="none")
                weight_idx = min(num_levels + 1 - w, len(level_weights) - 1)
                loss_levels.append(level_weights[weight_idx] * l.mean([2, 3, 4]))

        loss = torch.stack(loss_levels, dim=0).mean()
    elif loss_type == "l1":
        loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
    elif loss_type == "huber":
        if huber_c is None:
            raise NotImplementedError("huber_c not implemented correctly")
        huber_c = huber_c.view(-1, 1, 1, 1)
        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
    elif loss_type == "smooth_l1":
        if huber_c is None:
            raise NotImplementedError("huber_c not implemented correctly")
        huber_c = huber_c.view(-1, 1, 1, 1)
        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
    else:
        raise NotImplementedError(f"Unsupported Loss Type: {loss_type}")
    return loss

Define arg:

parser.add_argument(
        "--loss_type",
        type=str,
        default="l2",
        choices=["l1", "l2", "huber", "smooth_l1", "dwt"],
        help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2",
    )

b) After the imports, insert:

def compute_weights(timestep, width, shift):
    timestep = min(1000, 1000 - timestep + shift)
    x = torch.linspace(0, 1000, 1000).cuda()
    func = lambda x, mu, sigma, amplitude: amplitude * torch.exp(-(x - mu)**2 / (2 * sigma**2))
    y = func(x, mu=timestep, sigma=width, amplitude=1.0)
    weights = [y[i].item() for i in [0, 200, 400, 600, 800, 999]]
    return weights

c) Add to the imports:

import torch.nn.functional as F  
from pytorch_wavelets import DWTForward

I’ve been experimenting with different optimizers for a week now, and the quality achieved with this implementation is simply excellent. Far better than l1/l2/huber. It works fine with --debiased_estimation_loss, by the way.

I am trying to test this, but somehow my caching latents gets stuck at 0%, any idea why this could be?

And should it work with --loss_type dwt or should I also add something to the config.toml like wavelet_loss = true?

Thanks!

@deGENERATIVE-SQUAD
Copy link

@WinodePino

I am trying to test this, but somehow my caching latents gets stuck at 0%, any idea why this could be?

I’m not using the cache at all - I’ll try a test run with it later. Maybe another manual patch is needed for train_util. Did you get any specific errors when it got stuck, or was it just the freeze itself?

And should it work with --loss_type dwt or should I also add something to the config.toml like wavelet_loss = true?

It works with --loss_type="dwt" alone, as long as you’ve completed all the steps above.

@WinodePino
Copy link

@WinodePino

I am trying to test this, but somehow my caching latents gets stuck at 0%, any idea why this could be?

I’m not using the cache at all - I’ll try a test run with it later. Maybe another manual patch is needed for train_util. Did you get any specific errors when it got stuck, or was it just the freeze itself?

And should it work with --loss_type dwt or should I also add something to the config.toml like wavelet_loss = true?

It works with --loss_type="dwt" alone, as long as you’ve completed all the steps above.

Thanks, it is working now, but it doesnt seem to learn anything.

@deGENERATIVE-SQUAD
Copy link

@WinodePino

Thanks, it is working now, but it doesnt seem to learn anything.

If the iteration is proceeding at a normal speed, then obviously the problem lies in the too low feature update speed.
Show your full configuration, please.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants