
[BUG] Qwen-edit-plus training with sampling #654

@sdbds

Description

Extended arguments:
  --text_encoder=./ckpts/text_encoder/qwen_2.5_vl_7b.safetensors
  --edit_plus
  --flash_attn
  --split_attn
  --timestep_sampling=qinglong_qwen
  --guidance_scale=1
  --max_train_epochs=20
  --gradient_checkpointing
  --gradient_accumulation_steps=1
  --network_dim=32
  --network_alpha=16
  --network_module=networks.lora_qwen_image
  --lr_scheduler=cosine_with_min_lr
  --lr_scheduler_num_cycles=1
  --lr_decay_steps=0.2
  --lr_scheduler_min_lr_ratio=0.1
  --mixed_precision=bf16
  --fp8_base
  --fp8_scaled
  --persistent_data_loader_workers
  --blocks_to_swap=6
  --img_in_txt_in_offloading
  --optimizer_type=adv_optm.Adopt_adv
  --optimizer_args
  use_atan2=True
  grams_moment=True
  --save_every_n_epochs=2
  --sample_prompts=1.txt
  --sample_every_n_epochs=2

This only happens when sampling is enabled (--sample_prompts / --sample_every_n_epochs).

Traceback (most recent call last):
  File "D:\musubi-tuner-scripts\musubi-tuner\qwen_image_train_network.py", line 4, in <module>
    main()
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\qwen_image_train_network.py", line 480, in main
    trainer.train(args)
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\hv_train_network.py", line 2153, in train
    model_pred, target = self.call_dit(
                         ^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\qwen_image_train_network.py", line 427, in call_dit
    model_pred = model(
                 ^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\accelerate\utils\operations.py", line 818, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\accelerate\utils\operations.py", line 806, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\amp\autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\qwen_image\qwen_image_model.py", line 1157, in forward
    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\qwen_image\qwen_image_model.py", line 1077, in _gradient_checkpointing_func
    return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\_dynamo\eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\utils\checkpoint.py", line 503, in checkpoint
    ret = function(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\qwen_image\qwen_image_model.py", line 855, in forward
    img_mod_params = self.img_mod(temb)  # [B, 6*dim]
                     ^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\container.py", line 250, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\networks\lora.py", line 104, in forward
    org_forwarded = self.org_forward(x)
                    ^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\modules\fp8_optimization_utils.py", line 461, in new_forward
    return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\musubi-tuner-scripts\musubi-tuner\src\musubi_tuner\modules\fp8_optimization_utils.py", line 403, in fp8_linear_forward_patch
    dequantized_weight = dequantized_weight * self.scale_weight
                         ~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
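The failing line is a plain cross-device multiply: dequantized_weight lives on cuda:0 while self.scale_weight was left on cpu (e.g. after block swapping / offloading). A minimal sketch of a defensive fix, aligning the scale to the weight's device before the multiply; the function name and signature here are mine, not musubi-tuner's actual fp8_optimization_utils.py code:

```python
import torch

def fp8_dequant_matmul(x: torch.Tensor, weight: torch.Tensor, scale_weight: torch.Tensor) -> torch.Tensor:
    # Dequantize: cast the stored (quantized) weight back to the compute dtype.
    dequantized_weight = weight.to(x.dtype)
    # Defensive fix for the crash above: after offloading, scale_weight can be
    # left on cpu while the weight is back on cuda:0, so align it explicitly.
    scale = scale_weight.to(device=dequantized_weight.device, dtype=dequantized_weight.dtype)
    dequantized_weight = dequantized_weight * scale
    return torch.nn.functional.linear(x, dequantized_weight)
```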

After testing, I found that the problem is caused by the advanced optimizer (adv_optm.Adopt_adv)...
#598
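One plausible mechanism for this class of bug: if scale_weight is stored as a plain tensor attribute rather than a registered buffer, module.to(device) during block swapping will not move it. This toy stand-in (the class name and layout are assumptions taken from the traceback, not musubi-tuner's real implementation) shows the buffer-based pattern that keeps the scale on the same device as the weight:

```python
import torch
import torch.nn as nn

class FP8Linear(nn.Module):
    """Toy stand-in for a patched fp8 linear layer (not the real class)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # Registering as a buffer makes .to()/.cuda() move it with the module.
        # A plain attribute (self.scale_weight = torch.ones(...)) would be left
        # behind on cpu when the block is swapped back to the GPU.
        self.register_buffer("scale_weight", torch.ones(out_features, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale is guaranteed to be on the same device as the weight here.
        return nn.functional.linear(x, self.weight * self.scale_weight)
```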
