Skip to content

Commit 9a900bf

Browse files
authored
[megatron] support megatron all-router multimodal (modelscope#7951)
1 parent bed4d10 commit 9a900bf

File tree

4 files changed

+41
-22
lines changed

4 files changed

+41
-22
lines changed

swift/megatron/trainers/base.py

Lines changed: 24 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1118,19 +1118,28 @@ def unmerge_lora_adapters(self):
11181118
module.unmerge()
11191119

11201120
@staticmethod
1121-
def _copy_args(output_dir):
1122-
if is_last_rank():
1123-
args_path = os.path.join(os.path.dirname(output_dir), 'args.json')
1124-
if os.path.exists(args_path):
1125-
shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
1121+
def copy_path(src_path: str, tgt_path: str):
1122+
if not is_last_rank():
1123+
return
1124+
if not os.path.exists(src_path):
1125+
raise FileNotFoundError(f'Source path does not exist: {src_path}')
1126+
1127+
if os.path.isfile(src_path):
1128+
os.makedirs(os.path.dirname(tgt_path), exist_ok=True)
1129+
shutil.copy(src_path, tgt_path)
1130+
elif os.path.isdir(src_path):
1131+
shutil.copytree(src_path, tgt_path, dirs_exist_ok=True)
1132+
else:
1133+
raise ValueError(f'Source path is neither a file nor a directory: {src_path}')
11261134

11271135
def save_checkpoint(self, iteration, model, *_args, **kwargs):
11281136
args = get_args()
11291137
output_dir = os.path.join(args.save, f'checkpoint-{iteration}')
11301138
os.makedirs(output_dir, exist_ok=True)
11311139
origin_save = args.save
11321140
args.save = output_dir
1133-
self._copy_args(output_dir)
1141+
args_path = os.path.join(os.path.dirname(output_dir), 'args.json')
1142+
self.copy_path(args_path, os.path.join(output_dir, 'args.json'))
11341143
save_peft_format = args.tuner_type == 'lora' and not args.merge_lora
11351144
if args.save_safetensors and args.no_save_optim:
11361145
model = []
@@ -1142,9 +1151,17 @@ def save_checkpoint(self, iteration, model, *_args, **kwargs):
11421151
# merge-lora does not store lora, lora saving may report an error (Qwen3-VL-Moe)
11431152
if args.tuner_type == 'lora' and args.merge_lora:
11441153
self.merge_lora_adapters()
1154+
origin_output_dir = output_dir
11451155
output_dir = f'{output_dir}-merged'
11461156
os.makedirs(output_dir, exist_ok=True)
1147-
self._copy_args(output_dir)
1157+
for fname in ['latest_checkpointed_iteration.txt', 'args.json']:
1158+
src_path = os.path.join(origin_output_dir, fname)
1159+
self.copy_path(src_path, os.path.join(output_dir, fname))
1160+
# common.pt
1161+
common_path = os.path.join(origin_output_dir, f'iter_{iteration:07d}', 'common.pt')
1162+
tgt_common_path = os.path.join(output_dir, f'iter_{iteration:07d}', 'common.pt')
1163+
os.makedirs(os.path.dirname(tgt_common_path), exist_ok=True)
1164+
self.copy_path(common_path, tgt_common_path)
11481165
self.bridge.save_weights(
11491166
self.unwrapped_models,
11501167
output_dir,

swift/megatron/tuners/lora.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -202,7 +202,10 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
202202
lora_b.parallel_mode = self.base_layer.parallel_mode # fix moe_shared_expert_overlap
203203
for lora in [lora_a, lora_b]:
204204
if getattr(lora, 'parallel_mode', None) is None and hasattr(lora, 'weight'): # TODO: experts
205-
sequence_parallel = True if isinstance(self.base_layer, TopKRouter) else self.sequence_parallel
205+
if isinstance(self.base_layer, TopKRouter):
206+
sequence_parallel = self.base_layer.weight.sequence_parallel
207+
else:
208+
sequence_parallel = self.sequence_parallel
206209
lora.weight.sequence_parallel = sequence_parallel
207210
self.lora_A[adapter_name] = lora_a
208211
self.lora_B[adapter_name] = lora_b

swift/megatron/utils/utils.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,10 @@
2828
logger = get_logger()
2929

3030

31-
def find_all_linears(model):
31+
def find_all_linears(model, extra_layers=None):
3232

3333
def _cond(name, module):
34-
if name != 'output_layer' and isinstance(
34+
if (extra_layers and isinstance(module, tuple(extra_layers))) or name != 'output_layer' and isinstance(
3535
module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, nn.Linear)):
3636
return True
3737
return False
@@ -54,6 +54,8 @@ def get_multimodal_target_regex(
5454
freeze_llm: bool = False,
5555
freeze_vit: bool = True,
5656
freeze_aligner: bool = True,
57+
include_embedding: bool = False,
58+
include_router: bool = False,
5759
) -> str:
5860
from ..model import get_megatron_model_meta
5961
megatron_model_meta = get_megatron_model_meta(args.hf_model_type)
@@ -68,6 +70,11 @@ def get_multimodal_target_regex(
6870
if not freeze_aligner:
6971
modules += aligner
7072
assert len(modules) > 0, f'modules: {modules}'
73+
extra_layers = []
74+
if include_embedding:
75+
extra_layers.append(LanguageModelEmbedding)
76+
if include_router:
77+
extra_layers.append(TopKRouter)
7178

7279
res = []
7380
for module in modules:
@@ -80,13 +87,13 @@ def get_multimodal_target_regex(
8087
sub_module = deep_getattr(model, module)
8188
if sub_module is None:
8289
continue
83-
target_modules = find_all_linears(sub_module)
90+
target_modules = find_all_linears(sub_module, extra_layers)
8491
if not target_modules:
8592
continue
8693
target_modules = [tm for tm in target_modules if tm]
8794
target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else ''
8895
rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else ''
89-
res.append(rf'{rejected_pattern}{module}{target_pattern}')
96+
res.append(rf'{rejected_pattern}{module}(?=\.){target_pattern}')
9097

9198
return rf'^({"|".join(res)})$'
9299

@@ -103,6 +110,8 @@ def get_target_modules(args, model):
103110
freeze_llm=args.freeze_llm,
104111
freeze_vit=args.freeze_vit,
105112
freeze_aligner=args.freeze_aligner,
113+
include_embedding='all-embedding' in target_modules,
114+
include_router='all-router' in target_modules,
106115
)
107116
else:
108117
target_modules.remove('all-linear')

swift/tuners/peft.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -86,19 +86,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional
8686

8787

8888
def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
89-
all_supported_names = ('linear', )
90-
all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D, lora.Linear)
91-
target_modules = getattr(peft_config, 'target_modules', None)
92-
target_parameters = getattr(peft_config, 'target_parameters', None)
9389
if target is None:
9490
return
9591

96-
if isinstance(target_modules, str) and not any(
97-
[name in target.__class__.__name__.lower()
98-
for name in all_supported_names]) and not any([isinstance(target, type_)
99-
for type_ in all_supported_types]) and not target_parameters:
100-
return
101-
10292
if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
10393
return
10494

0 commit comments

Comments (0)