Skip to content

Commit 1a838eb

Browse files
authored
[misc] fix model merger (#479)
1 parent effbf71 commit 1a838eb

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

scripts/model_merger.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
AutoConfig,
2525
AutoModelForCausalLM,
2626
AutoModelForTokenClassification,
27-
AutoModelForVision2Seq,
27+
AutoModelForImageTextToText,
2828
PretrainedConfig,
2929
PreTrainedModel,
3030
)
@@ -165,10 +165,10 @@ def process_one_shard(rank, model_state_dict_lst):
165165

166166
if "ForTokenClassification" in architectures[0]:
167167
AutoClass = AutoModelForTokenClassification
168+
elif "ForConditionalGeneration" in architectures[0]:
169+
AutoClass = AutoModelForImageTextToText
168170
elif "ForCausalLM" in architectures[0]:
169171
AutoClass = AutoModelForCausalLM
170-
elif "ForConditionalGeneration" in architectures[0]:
171-
AutoClass = AutoModelForVision2Seq
172172
else:
173173
raise NotImplementedError(f"Unknown architecture {architectures}.")
174174

verl/workers/fsdp_workers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,14 @@ def _build_model_optimizer(
191191
torch_dtype = PrecisionType.to_dtype(fsdp_config.torch_dtype)
192192

193193
if role == "critic":
194-
auto_class = AutoModelForTokenClassification
194+
AutoClass = AutoModelForTokenClassification
195195
elif type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
196-
auto_class = AutoModelForImageTextToText
196+
AutoClass = AutoModelForImageTextToText
197197
else:
198-
auto_class = AutoModelForCausalLM
198+
AutoClass = AutoModelForCausalLM
199199

200200
if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
201-
model = auto_class.from_pretrained(
201+
model = AutoClass.from_pretrained(
202202
model_config.model_path,
203203
config=self.model_config,
204204
torch_dtype=torch_dtype,
@@ -209,7 +209,7 @@ def _build_model_optimizer(
209209
)
210210
else:
211211
with no_init_weights(), init_empty_weights():
212-
model = auto_class.from_config(
212+
model = AutoClass.from_config(
213213
self.model_config,
214214
torch_dtype=torch_dtype,
215215
attn_implementation="flash_attention_2",

0 commit comments

Comments
 (0)