Tips and definitions for newbies to diffusion training #288
You could add a section on the logic of video training vs. image training - I mean, how each actually happens, in both training and inference.
Really amazing post, thanks for taking the time to explain it all. I have a question regarding LoRA+ - how should one go about adjusting the LR when using loraplus for Wan training? Right now I am using an LR of 0.00005 for my training and it's OK. Say I want to use "loraplus_lr_ratio=4" - to achieve similar results, would I have to adjust my LR to something like 0.0002? Or is my math wrong here?
Heya! I decided to write this because I find a lot of people need similar questions answered, and while there are a LOT of guides out there, a lot of them assume some level of technicality/expertise up front. This post is simply meant to be a breakdown of some common terms and concepts you will encounter when training diffusion models like Wan, Hunyuan, Flux, Stable Diffusion, etc. It's purposefully kept basic and broad to give those just beginning a foundation of knowledge to work with, and I try to assume as little prior knowledge as possible other than a basic familiarity with diffusion models. Additionally, I've tried to lay the document out so that nothing is referenced before it's explained. I'll probably update this post with more things as I think of them and have time. This is NOT meant to be a guide for training a specific model, rather a collection of knowledge.
Preamble:
You likely know this much, but for training you need a dataset - a collection of images/videos showing things you want the model to learn. Along with these you need accompanying captions - descriptions that represent the prompts that would have created those images/videos if they had been created with diffusion (regardless of whether they were or not).
The first thing you should understand is the basic flow of the training process. In very simplified terms, what happens is we generate a sample using the captions provided for training, then compare that sample to the image/video the caption belongs to. The difference is "loss", and we then update our weights to try to reduce that loss. Repeat this many times and that's training. Now that you have this basic knowledge, here are some definitions of various training-adjacent terms:
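But first, here's the whole loop as a toy PyTorch sketch. Everything in it (the tiny linear "model", the random "dataset") is a stand-in, not how a real diffusion trainer works internally, but the shape of the loop is the same:

```python
import torch

# Toy stand-ins: a "model" mapping a caption embedding to an image tensor,
# and a dataset of (caption embedding, image) pairs. Purely illustrative.
model = torch.nn.Linear(8, 64)
dataset = [(torch.randn(8), torch.randn(64)) for _ in range(30)]
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

for epoch in range(3):                       # each full pass = one epoch
    for caption_emb, image in dataset:       # each item = one step (batch size 1)
        pred = model(caption_emb)            # "generate a sample" for this caption
        loss = (pred - image).pow(2).mean()  # loss: how far off are we?
        loss.backward()                      # backward pass: compute gradients
        optimizer.step()                     # update the weights (learning!)
        optimizer.zero_grad()
```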
Definitions:
LoRA: LoRA ([Lo]w [R]ank [A]daptation) is a technology that was created to allow us to personalize/train these models more effectively. You can think of it like a small "sub model" that you temporarily merge into the main model during inference to teach it a new concept. Many people create and share LoRA for all kinds of different models and concepts on sites like CivitAI and others. Each one is specific to the base model it was trained on. There are also advanced types of LoRA such as LyCORIS, which is a collection of advanced LoRA methods like LoHa and LoKr, as well as DoRA and more. The basic goal with all of these is the same - create a small sub model to learn a new thing - and the different methods are different ways of doing that. They might modify different parts of the model, for instance. Beginners should start with basic LoRA though, as it's quite sufficient for most tasks! Functionally you can think of it like a small model that sits between the base model and the output and whose purpose is to modify the base model's output to contain the new concept.
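If you like code, here's the core trick in a few lines of PyTorch (the dimensions are made up for illustration): the big base weight W stays frozen, and two small matrices A ("down") and B ("up") learn the adjustment:

```python
import torch

d_out, d_in, rank = 320, 320, 16      # rank (network dimension) << layer size

W = torch.randn(d_out, d_in)          # frozen base model weight (not trained)
A = torch.randn(rank, d_in) * 0.01    # LoRA "down"/A matrix (trained)
B = torch.zeros(d_out, rank)          # LoRA "up"/B matrix (trained, starts at zero)

x = torch.randn(d_in)
y = W @ x + B @ (A @ x)               # base output plus the LoRA's learned adjustment
```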
E notation: E notation refers to numbers like 2e-5. It's an alternate form of scientific notation: 2e-5 is 2 * 10^-5 (two times ten to the negative fifth power) or 0.00002 and it's useful for specifying very small values compactly!
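You can check this in any Python console:

```python
print(2e-5)             # 2e-05
print(2e-5 == 0.00002)  # True: two ways of writing the same number
```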
Learning Rate: Often shortened to just LR, just like it sounds this is the rate at which we update the model's weights each step to try to learn the thing. Higher will learn faster, but too high and you may overshoot - have you ever run so fast you couldn't quite stop where you meant to? It's like that. Or you might cause the model to collapse (burn out). Lower of course is slower. Like many of these parameters the optimal value is a balance - you don't wanna take forever but you wanna get good results. Learning rate is often specified in e notation.
Loss: As previously alluded to, loss is the difference between what the model is currently capable of making and what we want it to be able to make. This is often monitored during training with metrics like "avr_loss" (average loss) and is useful for telling how the training is progressing. A lower loss means we are closer to where we wanna be - the training sample we generated is closer to the example provided in our dataset. Note that MANY things can affect loss; it's a broad measurement of how training is going, and low loss does not guarantee success - it just suggests things are moving in the right direction.
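For instance, here's mean squared error (a common loss) on toy numbers - the closer the prediction gets to the target, the smaller it gets:

```python
import torch
import torch.nn.functional as F

prediction = torch.tensor([0.9, 0.1, 0.4])
target = torch.tensor([1.0, 0.0, 0.5])

loss = F.mse_loss(prediction, target)  # average of the squared differences
print(loss.item())                     # ~0.01 -- lower = closer to the target
```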
Network Dimension: Also sometimes referred to as "Rank", this simply determines the size of a LoRA, e.g. how many parameters it has and how much space it takes up. The final size is a function of the base model and the network dimension. So for instance, a rank 32 LoRA for SD1.5 will be much smaller than one for Flux, because SD1.5 is much smaller than Flux. This functionally determines how much your LoRA can learn. Bigger can learn more, but has higher risk of concept bleed and plays less well with other LoRA. Smaller can learn less, but hopefully only learns what you want and not what you don't. Ideally you wanna use the smallest size that is sufficient. In practice this one can be tricky to tune. Generally you want higher values for smaller models and vice versa.
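As a rough back-of-the-envelope (the 3072 below is just an example layer width, not any real model's):

```python
# For one adapted linear layer, a LoRA adds an A matrix (rank x in_features)
# and a B matrix (out_features x rank):
in_features, out_features, rank = 3072, 3072, 32
params_per_layer = rank * in_features + out_features * rank
print(params_per_layer)  # 196608 -- and this repeats across every adapted layer,
# which is why the same rank produces a much bigger file on a bigger base model
```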
Network Alpha: This value in combination with network dimension controls a feature called alpha scaling. Because of the way decimal numbers are stored in computers (see Datatype below), very small values can be rounded to 0 or otherwise compromised. Alpha scaling is an attempt to fix this by amplifying the weights during training and restoring them during inference. The effect is active when network alpha is less than network dimension; functionally we divide alpha by dimension and take the square root of that result to create the scale, so lower network_alpha values create stronger effects. Its utility varies with the model, but it tends to be more helpful when you haven't got your other parameters fully dialed in - enabling the effect tends to soften the changes to the model and make learning more forgiving, but might make getting a perfect result impossible. Disabling it tends to make training more unforgiving, but if you get it right the result might be better. Note that alpha scaling is supported by kohya's projects (sd-scripts, Musubi Tuner) but not by, for instance, diffusion-pipe and some others. Increased alpha scaling tends to require increasing LR to compensate.
LoRAPlus: Internally, LoRA have two different types of blocks - sometimes referred to as A and B, or down and up. Normally these both train at the same LR, but research has shown that this is not optimal ( https://arxiv.org/abs/2402.12354 ). LoRAPlus simply refers to a multiplier applied to the LR of LoRA B/up*, and in practice it helps you reach better convergence (the optimal result) more quickly. Good values depend on the model; my personal experience with this is limited to Wan and Hunyuan, for which I'd recommend a value of 4. *Naming conventions for internal LoRA structures are somewhat inconsistent across projects; this references kohya's projects specifically.
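In PyTorch terms, LoRAPlus boils down to giving the two sets of matrices different learning rates via optimizer parameter groups. A minimal sketch, with made-up stand-in parameters:

```python
import torch

# Stand-ins for one layer's LoRA matrices (rank 16 on a 320-wide layer)
lora_down = torch.nn.Parameter(torch.randn(16, 320) * 0.01)  # "A"/down
lora_up = torch.nn.Parameter(torch.zeros(320, 16))           # "B"/up

base_lr = 5e-5
loraplus_lr_ratio = 4  # the B/up side trains 4x faster

optimizer = torch.optim.AdamW([
    {"params": [lora_down], "lr": base_lr},
    {"params": [lora_up],   "lr": base_lr * loraplus_lr_ratio},
])
```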
Gradient descent: Imagine you are standing at the top of a tall mountain. At the very bottom of this mountain, in a deep valley, is a delicious cookie you REALLY want (the optimal training target). But also on and around the mountain there are other shallow areas and valleys. These also contain cookies, but /far less tasty ones/. They are local minima - similar to the cookie you are after in some ways, but not that perfect cookie. They might be missing the chocolate chips, or maybe they were baked too long, or have a bite taken out of them (yikes!). Gradient descent refers to algorithms designed to help you descend that mountain and reach the perfect cookie at the bottom with as little effort as possible and without getting stuck in local minima - without getting distracted by suboptimal cookies on the way down 😉 Functionally it's how we feel our way through the training process and decide how much to update our weights by, and in which direction, when training.
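Stripped of cookies, the core move is tiny: measure the slope, step downhill, repeat. A toy one-dimensional example:

```python
import torch

# Toy 1-D "mountain": f(w) = (w - 3)^2, with the cookie (minimum) at w = 3
w = torch.tensor([0.0], requires_grad=True)
lr = 0.1  # learning rate: how big each downhill step is

for step in range(100):
    loss = ((w - 3) ** 2).sum()  # how far we are from the cookie
    loss.backward()              # gradient: the slope under our feet
    with torch.no_grad():
        w -= lr * w.grad         # step downhill
    w.grad.zero_()

print(w.item())  # ~3.0: we found the cookie
```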
Optimizer: This is the name of an algorithm (such as adamw, adamw8bit, adafactor, came_pytorch) responsible for controlling the gradient descent process during training. Generally you will find one you like and stick with it, and the default is quite sufficient for beginners. Despite the cute cookie analogy, mathematically gradient descent is super complex and non-trivial - our subject REALLY wants their cookie RIGHT NOW, so it's hard to keep them away from the imperfect ones, and we have limited information about the environment around us. As a result there are many different ways of approaching the process and many potential tradeoffs - do we optimize for speed, accuracy, memory usage, etc. Optimizers also have arguments like "weight decay" (which gently shrinks the weights a little each step so the model doesn't over-rely on any of them, thwarting overfitting) and more, but generally those are advanced tweaks.
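Picking one in code is usually a single line; here with the common AdamW and a typical weight decay value (all numbers illustrative):

```python
import torch

model = torch.nn.Linear(320, 320)  # stand-in for whatever you're training

# AdamW is a solid default; weight_decay shrinks the weights slightly each step
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
```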
Scheduler: A scheduler is logic for changing the learning rate over time. Constant is like it sounds, but there are also Constant with warmup (LR increases from 0 to the target LR over warmup_steps, then stays constant), Cosine (LR decreases to 0 over the course of the run), Cosine with restarts (Cosine, but multiple cycles instead of just one long decrease) and more. What's appropriate depends on the model and what you're trying to achieve. Advanced methods can help you get tighter convergence, but for beginners constant or constant with warmup is fine.
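The shapes are easy to see in a hand-rolled sketch (real trainers use library schedulers; this is just to show the curves):

```python
import math

def lr_at_step(step, base_lr=2e-5, warmup_steps=100, total_steps=2000):
    """Warmup from 0 to base_lr, then cosine decay back down to 0."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps                    # warmup ramp
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay

print(lr_at_step(50), lr_at_step(100), lr_at_step(2000))  # ramping, peak, ~0
```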
Epoch: An epoch [eh-pok (IPA: ĕp′ək), or sometimes ee-pok (IPA: ē′pŏk″)] is simply one round of batches during training. It's one pass through all the stuff in your dataset. It's one way of specifying how long a training ran/will run, along with steps/iterations. If your dataset has 30 videos and assuming batch_size and num_repeats = 1, then an epoch is 30 steps. If your batch size was 2, it would be 15.
Attention: This is basically the entire concept that makes these models possible. These are algorithms and methods for allowing the model to determine the relative importance of components in a sequence - it's how the model decides what parts of your prompt are important to focus on. Laid out in the landmark Attention Is All You Need paper ( https://arxiv.org/abs/1706.03762 ), this is the magic of the transformer. Transformers are a type of machine learning model; examples are HunyuanVideo, Flux, Wan2.1, Stable Diffusion, any version of GPT ([G]enerative [P]re-trained [T]ransformer), Gemma2, etc. Diffusion models, LLMs, and more use the transformer architecture for its powerful, flexible capabilities. More specifically for training, you might need to choose an attention implementation - such as SDPA, Flash Attention, xformers, etc. These are different algorithms that implement attention, and which one you will want to choose will depend on your hardware. It /shouldn't/ significantly impact your results, but it can affect your speed. Generally some version of Flash Attention is preferred for training when available.
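PyTorch exposes SDPA directly, and the underlying math fits in a comment (toy shapes below):

```python
import torch
import torch.nn.functional as F

# 1 batch, 4 tokens, 8 dims. Q asks "what am I looking for?",
# K says "what do I contain?", V is "what do I pass along?"
q, k, v = (torch.randn(1, 4, 8) for _ in range(3))

out = F.scaled_dot_product_attention(q, k, v)
# Equivalent math: softmax(q @ k.transpose(-2, -1) / sqrt(8)) @ v
print(out.shape)  # torch.Size([1, 4, 8])
```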
Seed: A seed is a number fed into a random number generator to condition its state. Random numbers on computers are actually generated algorithmically (except in extremely specialized cases). A seed allows deterministic randomness (nice oxymoron!). For an example, imagine you have a seed of 4. You ask for a random number and get "5", then again and get "7", then "4", etc. If you close/reset the program and start again with the same seed of 4, the same sequence of 5, 7, 4, etc. will come out of the random number generator. But if you give it a different seed, a different sequence emerges. For training it's not a huge deal though, so just pick a number you like, such as a birthday or phone number 🙂 Note it IS possible to get a bad or unfortunate seed during training, so it's worth changing it about when trying to get better results!
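You can watch this happen in a few lines of PyTorch:

```python
import torch

torch.manual_seed(4)
print(torch.rand(3))  # some "random" numbers

torch.manual_seed(4)
print(torch.rand(3))  # reseeded with 4: the exact same numbers again
```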
Forward/Backward pass: These are different ways of interacting with the model's weights. The forward pass is the denoising pass which creates images/video. The backward pass is the backpropagation step which updates the model's weights to try to learn something new. They are named as such because the forward pass is calculated from first to last layer and the backward pass from last to first.
Gradient Checkpointing: Gradient checkpointing is a way we reduce memory usage during training. Instead of storing all the intermediate activations during the forward pass, we only store some; the others will be recomputed during the backward pass when they're needed to calculate the gradients. Basically an exchange of lower memory usage for lower speed. (Gradients are the values that tell us how much, and in which direction, to update each weight during gradient descent; computing them requires those stored activations.)
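In PyTorch it looks like wrapping a chunk of the model; this sketch checkpoints one toy block:

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(8, 64, requires_grad=True)

# Don't keep this block's intermediate activations for the backward pass;
# recompute them when needed. Less memory, a bit more compute.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```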
Datatype: "dtype" for short. You might know this, but computers normally work in base 2, or binary - 10010110 is a single binary byte, for instance. A datatype is a specific way of representing numbers in binary. As it relates to machine learning and diffusion specifically, usually you will be dealing with "floating point" datatypes, which store non-whole numbers. The "floating" means the decimal point can be in an arbitrary place: it could be 3.14 (where it's at position 2) or it could be 124.1456 (where it's at position 4), for instance. Examples of floating point dtypes are fp32 aka float32, fp16 aka float16, bf16 aka bfloat16 (the b stands for brain), etc. The main thing for beginners to know here is that these are /not/ capable of perfectly representing every possible fractional number - they have limitations in both the highest/lowest value they can store and the precision with which they can do it. The number part of the datatype - the 32 in "float32" - is how many bits are being used to represent each number. Generally more bits = more accurate, but uses more memory and is (usually) slower. When it comes to fp16 vs bf16, both use 16 bits but those bits are allocated differently - fp16 allocates more to the mantissa, which is the name for the bits that store the fractional detail of the number. This allows it to store more exact values. bf16 allocates more to the exponent, which sets the scale of the number, so it can store a larger range of values, but less precisely. Which is better depends on the use case. All models have a datatype they are stored in. This is usually, but not always, the same datatype their calculations will be done in. Which brings us to...
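You can poke at the fp16 vs bf16 tradeoff directly (printed values are approximate):

```python
import torch

print(torch.tensor(70000.0, dtype=torch.float16))   # inf! fp16 tops out around 65504
print(torch.tensor(70000.0, dtype=torch.bfloat16))  # ~70144: bf16's larger range copes

print(torch.tensor(0.1234, dtype=torch.float16))    # ~0.1234: extra mantissa bits keep detail
print(torch.tensor(0.1234, dtype=torch.bfloat16))   # ~0.1235: bf16 is coarser here
```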
Quantization: Quantization, as it refers to diffusion models, covers various ways of changing the datatype of the model to something with fewer bits so that it uses less memory while retaining as much accuracy as possible. Various methods can be used to achieve this. Beginners really just need to know it's a way to lower the amount of memory used in exchange for a (hopefully minimized) loss in precision. Generally quantization methods are classified by the number of bits they store the parameters in, and there are lots of different methods. Musubi supports fp8_e4m3fn (8 bits of precision allocated as 1 sign bit, 4 exponent bits, and 3 mantissa bits; the fn stands for "finite numbers", meaning it has no way to represent infinity) and fp8_scaled, which uses the same fp8_e4m3fn datatype in a smarter way.
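Recent PyTorch versions ship the fp8 dtypes if you want to see the precision loss first-hand (printed results are approximate):

```python
import torch  # float8 dtypes need PyTorch 2.1 or newer

x = torch.tensor([0.1234, 3.5, 200.0])
x8 = x.to(torch.float8_e4m3fn)  # squeeze each value into 8 bits
print(x8.to(torch.float32))     # values come back noticeably rounded
```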
GeMM: Short for [Ge]neral [M]atrix [M]ultiply, matrix multiplications are a massive part of the actual work being done when you inference or train a diffusion model and thus optimizing the algorithm which does these multiplications is highly beneficial as you'll see in the following entries!
fp8_fast: This is the common name for a method utilizing torch._scaled_mm to do fast matrix multiplication in fp8 precision on supported GPUs. Internally it means that instead of doing the model's calculations in its native dtype, we are doing a scaled quantization to fp8 and then doing the math, because some GPUs (notably RTX 40xx, i.e. the Ada Lovelace arch) are VERY fast at this. Fundamentally you are sacrificing a bit of quality to gain a LOT of speed.
fp16_fast/fp16_accumulation: These both refer to another way of accelerating matrix multiplication. Normally we would do the accumulation (the running sums inside the matrix multiply) in fp32, but instead we sacrifice some precision and do it in fp16 to gain a lot of speed. This is another accuracy/speed tradeoff. It's notably beneficial for Wan2.1 models, which suffer atypically large quality loss when using fp8_fast. It's slightly less of a boost than fp8_fast, but also doesn't sacrifice as much quality. PyTorch only added support for this in 2.7.0, and support in projects is currently somewhat limited (definitely supported by WanVideoWrapper and Blissful Tuner, possibly others; not yet supported by mainline Musubi).
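If you want to experiment yourself, newer PyTorch exposes a toggle for this. To the best of my knowledge the flag below is the 2.7+ spelling, but double check it against your installed version's docs:

```python
import torch

# Believed to require PyTorch >= 2.7 and a CUDA GPU; verify the flag name
# against your version's documentation before relying on it.
torch.backends.cuda.matmul.allow_fp16_accumulation = True
```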
[Work in progress]