19 changes: 12 additions & 7 deletions src/musubi_tuner/hv_train.py
@@ -132,12 +132,10 @@ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
             if torch.cuda.device_count() > 1
             else None
         ),
-        (
-            DistributedDataParallelKwargs(
-                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
-            )
-            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
-            else None
-        ),
+        DistributedDataParallelKwargs(
+            find_unused_parameters=True,
+            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
+            static_graph=args.ddp_static_graph
+        ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
@@ -897,6 +895,12 @@ def train(self, args):
         else:
             transformer = accelerator.prepare(transformer)

+        # Ensure DDP is properly configured for models with unused parameters
+        if hasattr(transformer, 'module') and hasattr(transformer.module, 'find_unused_parameters'):
+            transformer.module.find_unused_parameters = True
+        elif hasattr(transformer, 'find_unused_parameters'):
+            transformer.find_unused_parameters = True
+
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

         transformer.train()
@@ -1004,7 +1008,8 @@ def remove_model(old_ckpt_name):
         # training loop

         # log device and dtype for each model
-        logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+        unwrapped_transformer = accelerator.unwrap_model(transformer)
+        logger.info(f"DiT dtype: {unwrapped_transformer.dtype}, device: {unwrapped_transformer.device}")

         clean_memory_on_device(accelerator.device)

19 changes: 12 additions & 7 deletions src/musubi_tuner/hv_train_network.py
@@ -148,12 +148,10 @@ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
             if torch.cuda.device_count() > 1
             else None
         ),
-        (
-            DistributedDataParallelKwargs(
-                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
-            )
-            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
-            else None
-        ),
+        DistributedDataParallelKwargs(
+            find_unused_parameters=True,
+            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
+            static_graph=args.ddp_static_graph
+        ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]

Review comment (Owner), on the find_unused_parameters=True line:

According to the PyTorch documentation, specifying find_unused_parameters=True when it is not necessary will slow down training:
https://docs.pytorch.org/docs/stable/notes/ddp.html#internal-design

Therefore, as with the other DDP-related options, it would be preferable to expose this as an argument (for example, --ddp_find_unused_parameters).
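A minimal sketch of how the suggested flag could be wired in. The flag name --ddp_find_unused_parameters is the one proposed in the comment above; the argparse and Accelerator wiring here is illustrative under that assumption, not code taken from the repository.

```python
# Sketch only: make find_unused_parameters opt-in via a CLI flag, as suggested above.
import argparse

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

parser = argparse.ArgumentParser()
parser.add_argument("--ddp_gradient_as_bucket_view", action="store_true")
parser.add_argument("--ddp_static_graph", action="store_true")
parser.add_argument(
    "--ddp_find_unused_parameters",
    action="store_true",
    help="pass find_unused_parameters=True to DDP (needed when some parameters receive no gradient; slows training otherwise)",
)
args = parser.parse_args()

# Build the DDP kwargs handler only when at least one DDP option is requested,
# mirroring the conditional that the pre-PR code used.
ddp_kwargs = (
    DistributedDataParallelKwargs(
        find_unused_parameters=args.ddp_find_unused_parameters,
        gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
        static_graph=args.ddp_static_graph,
    )
    if args.ddp_find_unused_parameters or args.ddp_gradient_as_bucket_view or args.ddp_static_graph
    else None
)

kwargs_handlers = [h for h in [ddp_kwargs] if h is not None]
accelerator = Accelerator(kwargs_handlers=kwargs_handlers)
```

This keeps the default behavior unchanged (no handler, no slowdown) while letting multi-GPU runs with genuinely unused parameters opt in.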
@@ -1881,6 +1879,12 @@ def train(self, args):
         else:
             transformer = accelerator.prepare(transformer)

+        # Ensure DDP is properly configured for models with unused parameters
+        if hasattr(transformer, 'module') and hasattr(transformer.module, 'find_unused_parameters'):
+            transformer.module.find_unused_parameters = True
+        elif hasattr(transformer, 'find_unused_parameters'):
+            transformer.find_unused_parameters = True
+
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
         training_model = network

Review comment (Owner), on the added block:

There seems to be no point in overriding find_unused_parameters here; it will already be True if configured correctly.
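To illustrate the point above, a hedged sketch: after accelerator.prepare(), the DDP wrapper already carries whatever find_unused_parameters value was passed through DistributedDataParallelKwargs, so re-assigning the attribute changes nothing. The transformer name is the one from the hunk above; the check itself is illustrative.

```python
# Sketch: inspect the model returned by accelerator.prepare() in the diff above.
# If DistributedDataParallelKwargs(find_unused_parameters=True) was passed to the
# Accelerator, the wrapper's attribute is already True and the hasattr() override
# in this hunk is a no-op.
from torch.nn.parallel import DistributedDataParallel as DDP

if isinstance(transformer, DDP):
    print("find_unused_parameters:", transformer.find_unused_parameters)  # already True
```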

@@ -2116,7 +2120,8 @@ def remove_model(old_ckpt_name):
         # training loop

         # log device and dtype for each model
-        logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+        unwrapped_transformer = accelerator.unwrap_model(transformer)
+        logger.info(f"DiT dtype: {unwrapped_transformer.dtype}, device: {unwrapped_transformer.device}")

         clean_memory_on_device(accelerator.device)

Review comment (Owner), on the added lines:

Defining a new local variable unwrapped_transformer may prevent garbage collection later; it is better to call it directly: accelerator.unwrap_model(transformer).dtype and accelerator.unwrap_model(transformer).device.
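A sketch of the form the comment suggests, assuming the logger, accelerator, and transformer names used in the hunk above; no extra reference to the unwrapped model outlives the log call.

```python
# Sketch of the reviewer's suggestion: query dtype/device inline instead of
# binding the unwrapped model to a local variable.
logger.info(
    f"DiT dtype: {accelerator.unwrap_model(transformer).dtype}, "
    f"device: {accelerator.unwrap_model(transformer).device}"
)
```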
9 changes: 8 additions & 1 deletion src/musubi_tuner/qwen_image_train.py
@@ -296,6 +296,12 @@ def train(self, args):
         else:
             transformer = accelerator.prepare(transformer)

+        # Ensure DDP is properly configured for models with unused parameters
+        if hasattr(transformer, 'module') and hasattr(transformer.module, 'find_unused_parameters'):
+            transformer.module.find_unused_parameters = True
+        elif hasattr(transformer, 'find_unused_parameters'):
+            transformer.find_unused_parameters = True
+
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
         training_model = transformer

@@ -515,7 +521,8 @@ def remove_model(old_ckpt_name):
         # training loop

         # log device and dtype for each model
-        logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+        unwrapped_transformer = accelerator.unwrap_model(transformer)
+        logger.info(f"DiT dtype: {unwrapped_transformer.dtype}, device: {unwrapped_transformer.device}")

         clean_memory_on_device(accelerator.device)

14 changes: 7 additions & 7 deletions src/musubi_tuner/qwen_image_train_network.py
Review comment (Owner), on this file:

It appears that this file has been unintentionally modified.
@@ -436,20 +436,20 @@ def call_dit(
         if is_edit:
             model_pred = model_pred[:, :img_seq_len]

-        # unpack latents
-        model_pred = qwen_image_utils.unpack_latents(
+        # flow matching loss - compute loss on raw model output before unpacking
+        latents = latents.to(device=accelerator.device, dtype=network_dtype)
+        target = noise - latents
+
+        # unpack latents for loss calculation
+        model_pred_unpacked = qwen_image_utils.unpack_latents(
             model_pred,
             lat_h * qwen_image_utils.VAE_SCALE_FACTOR,
             lat_w * qwen_image_utils.VAE_SCALE_FACTOR,
             qwen_image_utils.VAE_SCALE_FACTOR,
         )

-        # flow matching loss
-        latents = latents.to(device=accelerator.device, dtype=network_dtype)
-        target = noise - latents
-
         # print(model_pred.dtype, target.dtype)
-        return model_pred, target
+        return model_pred_unpacked, target

     # endregion model specific
