Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/musubi_tuner/hv_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,13 +974,14 @@ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=
ARCHITECTURE_HUNYUAN_VIDEO,
time.time(),
title,
None,
args.metadata_reso,
args.metadata_author,
args.metadata_description,
args.metadata_license,
args.metadata_tags,
timesteps=md_timesteps,
is_lora=False,
custom_arch=args.metadata_arch,
)

save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
Expand Down
15 changes: 14 additions & 1 deletion src/musubi_tuner/hv_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,12 +2084,13 @@ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=
self.architecture,
time.time(),
title,
None,
args.metadata_reso,
args.metadata_author,
args.metadata_description,
args.metadata_license,
args.metadata_tags,
timesteps=md_timesteps,
custom_arch=args.metadata_arch,
)

metadata_to_save.update(sai_metadata)
Expand Down Expand Up @@ -2839,6 +2840,18 @@ def int_or_float(value):
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument(
"--metadata_reso",
type=str,
default=None,
help="resolution for model metadata (e.g., `1024,1024`) / メタデータに書き込まれるモデル解像度(例: `1024,1024`)",
)
parser.add_argument(
"--metadata_arch",
type=str,
default=None,
help="architecture for model metadata / メタデータに書き込まれるモデルアーキテクチャ",
)

# huggingface settings
parser.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion src/musubi_tuner/qwen_image_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,15 @@ def save_model(
self.architecture,
time.time(),
title,
None,
args.metadata_reso,
args.metadata_author,
args.metadata_description,
args.metadata_license,
args.metadata_tags,
timesteps=md_timesteps,
is_lora=False,
is_edit_plus=args.edit_plus,
custom_arch=args.metadata_arch,
)

metadata_to_save.update(sai_metadata)
Expand Down
6 changes: 5 additions & 1 deletion src/musubi_tuner/qwen_image_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
setup_parser_common,
read_config_from_file,
)
import logging
from musubi_tuner.utils.sai_model_spec import CUSTOM_ARCH_QWEN_IMAGE_EDIT_PLUS

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -53,6 +54,9 @@ def handle_model_specific_args(self, args):
self.default_guidance_scale = 1.0 # not used
self.is_edit = args.edit or args.edit_plus

if args.custom_arch is None and args.edit_plus:
args.custom_arch = CUSTOM_ARCH_QWEN_IMAGE_EDIT_PLUS # to notify Edit-Plus mode for sai_model_spec

def process_sample_prompts(
self,
args: argparse.Namespace,
Expand Down
29 changes: 25 additions & 4 deletions src/musubi_tuner/utils/sai_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
ARCH_FLUX_KONTEXT = "Flux.1-dev"
ARCH_QWEN_IMAGE = "Qwen-Image"
ARCH_QWEN_IMAGE_EDIT = "Qwen-Image-Edit"
ARCH_QWEN_IMAGE_EDIT_PLUS = "Qwen-Image-Edit-Plus"
CUSTOM_ARCH_QWEN_IMAGE_EDIT_PLUS = "@@Qwen-Image-Edit-Plus@@" # special custom architecture name for Qwen-Image-Edit-Plus

ADAPTER_LORA = "lora"

Expand Down Expand Up @@ -118,14 +120,15 @@ def build_metadata(
architecture: str,
timestamp: float,
title: Optional[str] = None,
reso: Optional[Union[int, Tuple[int, int]]] = None,
reso: Optional[Union[str, int, Tuple[int, int]]] = None,
author: Optional[str] = None,
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
is_lora: bool = True,
custom_arch: Optional[str] = None,
):
metadata = {}
metadata.update(BASE_METADATA)
Expand All @@ -152,11 +155,23 @@ def build_metadata(
arch = ARCH_QWEN_IMAGE
impl = IMPL_QWEN_IMAGE
elif architecture == ARCHITECTURE_QWEN_IMAGE_EDIT:
arch = ARCH_QWEN_IMAGE_EDIT
# We treat Qwen-Image-Edit and Qwen-Image-Edit-Plus the same for architecture and implementation
# So we must distinguish them by custom_arch if needed
impl = IMPL_QWEN_IMAGE_EDIT
if custom_arch is None:
arch = ARCH_QWEN_IMAGE_EDIT
elif custom_arch == CUSTOM_ARCH_QWEN_IMAGE_EDIT_PLUS:
arch = ARCH_QWEN_IMAGE_EDIT_PLUS
custom_arch = None # clear custom_arch to avoid override later
else:
arch = ARCH_QWEN_IMAGE_EDIT # override with custom_arch later
else:
raise ValueError(f"Unknown architecture: {architecture}")

# Override with custom architecture if provided
if custom_arch is not None:
arch = custom_arch

if is_lora:
arch += f"/{ADAPTER_LORA}"
metadata["modelspec.architecture"] = arch
Expand Down Expand Up @@ -207,8 +222,14 @@ def build_metadata(
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
reso = (1280, 720)
# resolution is defined in dataset, so use default here
# Use 1328x1328 for Qwen-Image models, 1024x1024 for Qwen-Image-Edit models, and 1280x720 for others (this is just a placeholder, actual resolution may vary)
if architecture == ARCHITECTURE_QWEN_IMAGE:
reso = (1328, 1328)
elif architecture == ARCHITECTURE_QWEN_IMAGE_EDIT:
reso = (1024, 1024)
else:
reso = (1280, 720)
if isinstance(reso, int):
reso = (reso, reso)

Expand Down