Skip to content

Commit b509090

Browse files
committed
modify save_pretrained
1 parent b3fa83e commit b509090

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name='torch4keras',
10-
version='v0.1.8',
10+
version='v0.1.9',
1111
description='Use torch like keras',
1212
long_description=long_description,
1313
long_description_content_type="text/markdown",

torch4keras/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import json
1111
import math
12+
import re
1213

1314

1415
class Trainer:
@@ -494,15 +495,20 @@ def save_weights(self, save_path:str, mapping:dict=None, trainable_only:bool=Fal
494495
log_info(f"Only trainable parameters saved and occupy {params_trainable}/{params_all}={ratio:.2f}%")
495496

496497
def save_pretrained(self, save_path:str, weight_map:dict=None, mapping:dict=None):
    '''Save the model weights under the pretrained model's keys so that the
    checkpoint can be loaded by the transformers package.

    The direct children of the unwrapped model are visited one by one. A named
    child that exposes its own ``save_pretrained`` is asked for its state dict
    (with ``write_to_disk=False``); every other child contributes its
    ``state_dict()`` with keys prefixed by ``"<child name>."``. Nothing is
    written when the collected state dict is empty.

    :param save_path: str, file path or directory to save to. A path ending in
        an extension made of letters/digits (e.g. ``.bin``, ``.pt``) is used as
        the target file; any other path is treated as a directory and the
        weights are written to ``pytorch_model.bin`` inside it.
    :param weight_map: dict, forwarded unchanged to each child's ``save_pretrained``
    :param mapping: dict, forwarded unchanged to each child's ``save_pretrained``
    '''
    state_dict = dict()
    for name, child in self.unwrap_model().named_children():
        if (name != '') and hasattr(child, 'save_pretrained'):
            # The child only builds its state dict here; writing happens once below.
            tmp = child.save_pretrained(save_path, weight_map, mapping, write_to_disk=False)
            state_dict.update(tmp or {})
        else:
            state_dict.update({f'{name}.{k}': v for k, v in child.state_dict().items()})

    if not state_dict:
        return

    # A trailing ".ext" of ASCII letters/digits marks save_path as a file.
    # The range must be A-Z: "A-z" would also accept "[", "\", "]", "^", "_" and "`",
    # misclassifying directories such as "out/run.v1_final" as files.
    if re.search(r'\.[a-zA-Z0-9]+$', save_path):
        target_path = save_path
    else:
        target_path = os.path.join(save_path, 'pytorch_model.bin')
    save_checkpoint(state_dict, target_path)
506512

507513
def resume_from_checkpoint(self, save_dir:str=None, model_path:str=None, optimizer_path:str=None, scheduler_path:str=None,
508514
steps_params_path:str=None, mapping:dict=None, verbose:int=0, strict:bool=True, **kwargs):

0 commit comments

Comments
 (0)