Hi, thank you for the excellent codebase — I really appreciate your work on this project.
I think I might have found a potential issue (or perhaps I’m misunderstanding the intended design).
The relevant code is in musubi-tuner/src/musubi_tuner/qwen_image_train.py, lines 315 to 336 at commit 919d611:
```python
# patch for fused backward pass, adafactor only
if args.fused_backward_pass:
    # use fused optimizer for backward pass: other optimizers will be supported in the future
    import musubi_tuner.modules.adafactor_fused as adafactor_fused

    adafactor_fused.patch_adafactor_fused(optimizer)

    for param_group, param_name_group in zip(optimizer.param_groups, param_names):
        for parameter, param_name in zip(param_group["params"], param_name_group):
            if parameter.requires_grad:

                def create_grad_hook(p_name, p_group):
                    def grad_hook(tensor: torch.Tensor):
                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                        optimizer.step_param(tensor, p_group)
                        tensor.grad = None

                    return grad_hook

                parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
```
My understanding
To implement fused_backward_pass, the code registers a register_post_accumulate_grad_hook on each trainable parameter so that the parameter is updated (via optimizer.step_param) as soon as its gradient has been accumulated, and the gradient is then discarded to reduce peak VRAM usage.
This seems to be a clever design for memory efficiency.
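For context, here is a minimal, self-contained sketch of the same pattern outside the project. This is not the repository's code: it uses a hand-rolled SGD update in place of the patched Adafactor's step_param, just to illustrate how stepping inside the post-accumulate hook lets each gradient be freed immediately.

```python
import torch

model = torch.nn.Linear(16, 16)
lr = 1e-3

def make_hook(lr: float):
    def hook(param: torch.Tensor):
        # fires once param.grad has been fully accumulated for this backward pass
        with torch.no_grad():
            param.add_(param.grad, alpha=-lr)  # toy SGD step, standing in for step_param
        param.grad = None  # free the gradient right away to keep peak VRAM low
    return hook

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(make_hook(lr))

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # by the time this returns, all parameters are updated and their grads are freed
```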
Potential issue
However, according to the official PyTorch documentation and this DDP reimplementation reference,
register_post_accumulate_grad_hook appears to fire before gradients are synchronized (all-reduced) across GPUs under DDP.
If that is the case, each rank would step its parameters using only its local, unsynchronized gradients, so the model parameters could gradually diverge across GPUs when fused_backward_pass is enabled.
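I have not verified this on this codebase, but here is a sketch of the kind of check I had in mind (assuming two processes launched with torchrun --nproc_per_node=2 and the gloo backend; none of the names below come from the repository): snapshot each gradient inside the post-accumulate hook and compare it with the synchronized gradient DDP leaves in param.grad once backward() has returned.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("gloo")
rank = dist.get_rank()
torch.manual_seed(0)  # deterministic init (DDP broadcasts rank 0's weights anyway)

model = DDP(torch.nn.Linear(8, 1))
seen_in_hook = {}

def make_hook(name):
    def hook(param: torch.Tensor):
        # snapshot the gradient as the hook sees it, without modifying anything
        seen_in_hook[name] = param.grad.detach().clone()
    return hook

for name, p in model.named_parameters():
    p.register_post_accumulate_grad_hook(make_hook(name))

x = torch.randn(4, 8) + rank  # different data per rank -> different local gradients
model(x).sum().backward()

for name, p in model.named_parameters():
    same = torch.allclose(seen_in_hook[name], p.grad)
    print(f"rank {rank} {name}: hook saw the synchronized gradient: {same}")

dist.destroy_process_group()
```

If the snapshot matches the final param.grad, the hook only ran after the all-reduce; if it does not, the hook saw the pre-sync local gradient, which would support the concern above.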
Request for clarification
Could you please confirm whether this understanding is correct?
If I’ve misunderstood how the synchronization or hook timing works in this implementation, I’d really appreciate any clarification.
Thank you again for sharing this great work!