Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/musubi_tuner/hv_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import ast
import asyncio
from datetime import timedelta
Expand Down Expand Up @@ -134,9 +134,11 @@
),
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
find_unused_parameters=args.ddp_find_unused_parameters,
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
if args.ddp_find_unused_parameters or args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
]
Expand Down Expand Up @@ -1004,7 +1006,7 @@
# training loop

# log device and dtype for each model
logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
logger.info(f"DiT dtype: {accelerator.unwrap_model(transformer).dtype}, device: {accelerator.unwrap_model(transformer).device}")

clean_memory_on_device(accelerator.device)

Expand Down Expand Up @@ -1340,6 +1342,11 @@
action="store_true",
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
)
parser.add_argument(
"--ddp_find_unused_parameters",
action="store_true",
help="enable find_unused_parameters for DDP. According to PyTorch docs, specifying True when not necessary will slow down training / DDPでfind_unused_parametersを有効にする。PyTorchのドキュメントによると、不要な場合にTrueを指定すると学習が遅くなる",
)

parser.add_argument(
"--sample_every_n_steps",
Expand Down
13 changes: 10 additions & 3 deletions src/musubi_tuner/hv_train_network.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import ast
import asyncio
from datetime import timedelta
Expand Down Expand Up @@ -150,9 +150,11 @@
),
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
find_unused_parameters=args.ddp_find_unused_parameters,
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
if args.ddp_find_unused_parameters or args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
]
Expand Down Expand Up @@ -2116,7 +2118,7 @@
# training loop

# log device and dtype for each model
logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
logger.info(f"DiT dtype: {accelerator.unwrap_model(transformer).dtype}, device: {accelerator.unwrap_model(transformer).device}")

clean_memory_on_device(accelerator.device)

Expand Down Expand Up @@ -2445,6 +2447,11 @@
action="store_true",
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
)
parser.add_argument(
"--ddp_find_unused_parameters",
action="store_true",
help="enable find_unused_parameters for DDP. According to PyTorch docs, specifying True when not necessary will slow down training / DDPでfind_unused_parametersを有効にする。PyTorchのドキュメントによると、不要な場合にTrueを指定すると学習が遅くなる",
)

parser.add_argument(
"--sample_every_n_steps",
Expand Down
2 changes: 1 addition & 1 deletion src/musubi_tuner/qwen_image_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import argparse
import json
import math
Expand Down Expand Up @@ -515,7 +515,7 @@
# training loop

# log device and dtype for each model
logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
logger.info(f"DiT dtype: {accelerator.unwrap_model(transformer).dtype}, device: {accelerator.unwrap_model(transformer).device}")

clean_memory_on_device(accelerator.device)

Expand Down
14 changes: 7 additions & 7 deletions src/musubi_tuner/qwen_image_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,20 +436,20 @@ def call_dit(
if is_edit:
model_pred = model_pred[:, :img_seq_len]

# unpack latents
model_pred = qwen_image_utils.unpack_latents(
# flow matching loss - compute loss on raw model output before unpacking
latents = latents.to(device=accelerator.device, dtype=network_dtype)
target = noise - latents

# unpack latents for loss calculation
model_pred_unpacked = qwen_image_utils.unpack_latents(
model_pred,
lat_h * qwen_image_utils.VAE_SCALE_FACTOR,
lat_w * qwen_image_utils.VAE_SCALE_FACTOR,
qwen_image_utils.VAE_SCALE_FACTOR,
)

# flow matching loss
latents = latents.to(device=accelerator.device, dtype=network_dtype)
target = noise - latents

# print(model_pred.dtype, target.dtype)
return model_pred, target
return model_pred_unpacked, target

# endregion model specific

Expand Down
Loading