|
9 | 9 | import os |
10 | 10 | import json |
11 | 11 | import math |
| 12 | +import re |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class Trainer: |
@@ -494,15 +495,20 @@ def save_weights(self, save_path:str, mapping:dict=None, trainable_only:bool=Fal |
494 | 495 | log_info(f"Only trainable parameters saved and occupy {params_trainable}/{params_all}={ratio:.2f}%") |
495 | 496 |
|
496 | 497 | def save_pretrained(self, save_path:str, weight_map:dict=None, mapping:dict=None): |
497 | | - '''按照预训练模型的key来保存模型, 可供transformers包加载''' |
| 498 | + '''按照预训练模型的key来保存模型, 可供transformers包加载 |
| 499 | +
|
| 500 | + :param save_path: str, 保存的文件/文件夹路径 |
| 501 | + ''' |
498 | 502 | state_dict = dict() |
499 | 503 | for name, child in self.unwrap_model().named_children(): |
500 | 504 | if (name != '') and hasattr(child, 'save_pretrained'): |
501 | 505 | tmp = child.save_pretrained(save_path, weight_map, mapping, write_to_disk=False) |
502 | 506 | state_dict.update(tmp if tmp else {}) |
503 | 507 | else: |
504 | 508 | state_dict.update({f'{name}.{k}': v for k,v in child.state_dict().items()}) |
505 | | - save_checkpoint(state_dict, save_path) |
| 509 | + if len(state_dict) > 0: |
| 510 | + save_dir = None if re.search('\.[a-zA-z0-9]+$', save_path) else save_path |
| 511 | + save_checkpoint(state_dict, os.path.join(save_dir, 'pytorch_model.bin') if save_dir else save_path) |
506 | 512 |
|
507 | 513 | def resume_from_checkpoint(self, save_dir:str=None, model_path:str=None, optimizer_path:str=None, scheduler_path:str=None, |
508 | 514 | steps_params_path:str=None, mapping:dict=None, verbose:int=0, strict:bool=True, **kwargs): |
|
0 commit comments