Fixing multi-GPU Qwen Image fine-tuning training #674
Base branch: main
```diff
@@ -148,12 +148,10 @@ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
             if torch.cuda.device_count() > 1
             else None
         ),
-        (
-            DistributedDataParallelKwargs(
-                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
-            )
-            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
-            else None
+        DistributedDataParallelKwargs(
+            find_unused_parameters=True,
+            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
+            static_graph=args.ddp_static_graph
         ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
```
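For context, a minimal sketch (not this repository's code) of how such a kwargs handler reaches DDP through Accelerate: the `DistributedDataParallelKwargs` entry in `kwargs_handlers` supplies the keyword arguments that `accelerator.prepare()` later passes to torch's `DistributedDataParallel`. The flag values below are placeholders for the corresponding `args.ddp_*` options.

```python
# Minimal sketch, assuming Hugging Face Accelerate is installed; values are
# placeholders for args.ddp_gradient_as_bucket_view / args.ddp_static_graph.
from accelerate import Accelerator, DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True,    # with this PR, always passed to DDP
    gradient_as_bucket_view=False,  # placeholder for args.ddp_gradient_as_bucket_view
    static_graph=False,             # placeholder for args.ddp_static_graph
)

# Handlers set to None are filtered out, mirroring the list comprehension above.
kwargs_handlers = [h for h in [ddp_kwargs] if h is not None]
accelerator = Accelerator(kwargs_handlers=kwargs_handlers)
# On multi-GPU runs, accelerator.prepare(model) wraps the model in
# DistributedDataParallel using these keyword arguments.
```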
```diff
@@ -1881,6 +1879,12 @@ def train(self, args):
         else:
             transformer = accelerator.prepare(transformer)
 
+        # Ensure DDP is properly configured for models with unused parameters
+        if hasattr(transformer, 'module') and hasattr(transformer.module, 'find_unused_parameters'):
+            transformer.module.find_unused_parameters = True
+        elif hasattr(transformer, 'find_unused_parameters'):
+            transformer.find_unused_parameters = True
+
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
         training_model = network
 
```
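For reference, a small sketch (not this repository's code) of what the `hasattr` checks above are probing: on multi-GPU runs `accelerator.prepare()` returns a `DistributedDataParallel` wrapper whose original model sits at `.module`, while `find_unused_parameters` is an attribute recorded on the wrapper when it is constructed.

```python
# Sketch only: distinguishing a DDP-wrapped model from a bare nn.Module.
# No process group is initialized here, so the wrapped case is described in
# comments rather than constructed.
import torch.nn as nn

def base_module(model: nn.Module) -> nn.Module:
    """Return the underlying module whether or not `model` is a DDP wrapper."""
    return model.module if hasattr(model, "module") else model

bare = nn.Linear(8, 8)          # stand-in for the real transformer
assert base_module(bare) is bare

# On multi-GPU runs, prepare() returns roughly:
#     DistributedDataParallel(bare, find_unused_parameters=True, ...)
# That wrapper exposes .module (the original model) and keeps the
# find_unused_parameters value it was constructed with, which is what the
# hasattr() branches in the diff above are looking for.
```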
```diff
@@ -2116,7 +2120,8 @@ def remove_model(old_ckpt_name):
         # training loop
 
         # log device and dtype for each model
-        logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+        unwrapped_transformer = accelerator.unwrap_model(transformer)
+        logger.info(f"DiT dtype: {unwrapped_transformer.dtype}, device: {unwrapped_transformer.device}")
 
         clean_memory_on_device(accelerator.device)
```
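A minimal sketch of why the logging goes through `accelerator.unwrap_model()`: the DDP wrapper returned by `prepare()` on multi-GPU runs does not forward custom attributes such as `.dtype`, so they must be read from the unwrapped model. A generic `nn.Linear` stands in for the DiT here; the project's transformer exposes `.dtype` and `.device` directly, as in the diff above.

```python
# Sketch only: reading dtype/device via accelerator.unwrap_model(), which strips
# the DDP wrapper when one is present and is a no-op otherwise.
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(8, 8)                       # stand-in for the DiT/transformer
model = accelerator.prepare(model)            # DDP-wrapped only on multi-GPU runs
unwrapped = accelerator.unwrap_model(model)   # always the bare module

# nn.Linear has no .dtype/.device properties, so read them from a parameter here;
# the project's transformer exposes them directly.
param = next(unwrapped.parameters())
print(f"DiT dtype: {param.dtype}, device: {param.device}")
```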
It appears that this file has been unintentionally modified.
According to the PyTorch documentation, specifying `find_unused_parameters=True` when it is not necessary will slow down training: https://docs.pytorch.org/docs/stable/notes/ddp.html#internal-design

Therefore, as with the other DDP-related options, it would be preferable to be able to specify it as an argument (for example, `--ddp_find_unused_parameters`).
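A minimal sketch of the suggested flag; the name `--ddp_find_unused_parameters` and the wiring below follow the reviewer's example but are not part of this PR.

```python
# Sketch only: exposing find_unused_parameters as an opt-in CLI flag, mirroring
# the conditional handler that the original code used for the other DDP options.
import argparse
from accelerate import DistributedDataParallelKwargs

parser = argparse.ArgumentParser()
parser.add_argument("--ddp_gradient_as_bucket_view", action="store_true")
parser.add_argument("--ddp_static_graph", action="store_true")
parser.add_argument(
    "--ddp_find_unused_parameters",
    action="store_true",
    help="pass find_unused_parameters=True to DDP (slower; enable only when needed)",
)
args = parser.parse_args()

ddp_kwargs = (
    DistributedDataParallelKwargs(
        find_unused_parameters=args.ddp_find_unused_parameters,
        gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
        static_graph=args.ddp_static_graph,
    )
    if args.ddp_find_unused_parameters or args.ddp_gradient_as_bucket_view or args.ddp_static_graph
    else None
)
```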