     AutoConfig,
     AutoModelForCausalLM,
     AutoModelForTokenClassification,
-    AutoModelForVision2Seq,
+    AutoModelForImageTextToText,
     PretrainedConfig,
     PreTrainedModel,
 )
@@ -165,10 +165,10 @@ def process_one_shard(rank, model_state_dict_lst):
 
     if "ForTokenClassification" in architectures[0]:
         AutoClass = AutoModelForTokenClassification
+    elif "ForConditionalGeneration" in architectures[0]:
+        AutoClass = AutoModelForImageTextToText
     elif "ForCausalLM" in architectures[0]:
         AutoClass = AutoModelForCausalLM
-    elif "ForConditionalGeneration" in architectures[0]:
-        AutoClass = AutoModelForVision2Seq
     else:
         raise NotImplementedError(f"Unknown architecture {architectures}.")
 
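Below is a minimal, self-contained sketch of the substring dispatch this hunk changes: checkpoints whose architecture name contains "ForConditionalGeneration" are now routed to AutoModelForImageTextToText instead of AutoModelForVision2Seq, and that branch is checked before the "ForCausalLM" one. The helper name pick_auto_class is hypothetical, and the sketch assumes a transformers version that exposes AutoModelForImageTextToText; it is an illustration, not part of the diff.

# Minimal sketch (not part of the diff): how the substring dispatch above
# resolves a checkpoint's auto class after this change. pick_auto_class is
# a hypothetical helper; only the if/elif ordering mirrors the hunk.
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForTokenClassification,
)

def pick_auto_class(model_path: str):
    # config.architectures holds the concrete class name saved with the
    # checkpoint, e.g. ["LlamaForCausalLM"] or ["Qwen2VLForConditionalGeneration"].
    architectures = AutoConfig.from_pretrained(model_path).architectures
    if "ForTokenClassification" in architectures[0]:
        return AutoModelForTokenClassification
    elif "ForConditionalGeneration" in architectures[0]:
        # Vision-language checkpoints are now routed to AutoModelForImageTextToText
        # rather than AutoModelForVision2Seq.
        return AutoModelForImageTextToText
    elif "ForCausalLM" in architectures[0]:
        return AutoModelForCausalLM
    raise NotImplementedError(f"Unknown architecture {architectures}.")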
@@ -191,14 +191,14 @@ def _build_model_optimizer(
         torch_dtype = PrecisionType.to_dtype(fsdp_config.torch_dtype)
 
         if role == "critic":
-            auto_class = AutoModelForTokenClassification
+            AutoClass = AutoModelForTokenClassification
         elif type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
-            auto_class = AutoModelForImageTextToText
+            AutoClass = AutoModelForImageTextToText
         else:
-            auto_class = AutoModelForCausalLM
+            AutoClass = AutoModelForCausalLM
 
         if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
-            model = auto_class.from_pretrained(
+            model = AutoClass.from_pretrained(
                 model_config.model_path,
                 config=self.model_config,
                 torch_dtype=torch_dtype,
@@ -209,7 +209,7 @@ def _build_model_optimizer(
             )
         else:
             with no_init_weights(), init_empty_weights():
-                model = auto_class.from_config(
+                model = AutoClass.from_config(
                     self.model_config,
                     torch_dtype=torch_dtype,
                     attn_implementation="flash_attention_2",
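For context on the config-type test above: a minimal sketch, assuming a recent transformers release, of how membership of the config class in AutoModelForImageTextToText._model_mapping decides between the image-text-to-text and causal-LM auto classes. The Qwen/Qwen2-VL-7B-Instruct checkpoint name is only an illustrative example, not something taken from this diff.

# Minimal sketch (not part of the diff); the checkpoint name below is illustrative only.
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText

config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# _model_mapping maps config classes to model classes, so checking the config's
# type tells us whether the checkpoint carries an image-text-to-text head.
if type(config) in AutoModelForImageTextToText._model_mapping.keys():
    AutoClass = AutoModelForImageTextToText
else:
    AutoClass = AutoModelForCausalLM
model = AutoClass.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", config=config)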