1
1
from torch import nn
2
2
import torch
3
3
from torch4keras .snippets import DottableDict , metric_mapping , get_parameter_device , log_info , log_warn , log_error
4
- from torch4keras .snippets import print_trainable_parameters , colorful , monitor_run_by_email , load , save
4
+ from torch4keras .snippets import print_trainable_parameters , colorful , monitor_run_by_email , load_checkpoint , save_checkpoint
5
5
from torch4keras .callbacks import KerasProgbar , SmoothMetricsCallback , TqdmProgbar , ProgressBar2Progbar , Callback , CallbackList , History
6
6
from collections import OrderedDict
7
7
from typing import Union
@@ -460,7 +460,7 @@ def load_weights(self, load_path:Union[str,tuple,list], strict:bool=True, mappin
460
460
461
461
mapping = mapping or dict ()
462
462
for load_path_i in load_path :
463
- state_dict = load (load_path_i )
463
+ state_dict = load_checkpoint (load_path_i )
464
464
for k in list (state_dict .keys ()):
465
465
if k in mapping :
466
466
state_dict [mapping [k ]] = state_dict .pop (k )
@@ -486,9 +486,7 @@ def save_weights(self, save_path:str, mapping:dict=None, trainable_only:bool=Fal
486
486
if k in mapping :
487
487
state_dict [mapping [k ]] = state_dict .pop (k )
488
488
489
- save_dir = os .path .dirname (save_path )
490
- os .makedirs (save_dir , exist_ok = True )
491
- save (state_dict , save_path )
489
+ save_checkpoint (state_dict , save_path )
492
490
if trainable_only :
493
491
params_all = sum (p .numel () for p in self .unwrap_model ().parameters ())
494
492
params_trainable = sum (p .numel () for p in self .unwrap_model ().parameters () if p .requires_grad )
@@ -497,12 +495,15 @@ def save_weights(self, save_path:str, mapping:dict=None, trainable_only:bool=Fal
497
495
498
496
def save_pretrained (self , save_path :str , weight_map :dict = None , mapping :dict = None ):
499
497
'''按照预训练模型的key来保存模型, 可供transformers包加载'''
498
+ state_dict = dict ()
500
499
for name , child in self .unwrap_model ().named_children ():
501
500
if (name != '' ) and hasattr (child , 'save_pretrained' ):
502
- child .save_pretrained (save_path , weight_map , mapping )
501
+ tmp = child .save_pretrained (save_path , weight_map , mapping , write_to_disk = False )
502
+ state_dict .update (tmp if tmp else {})
503
503
else :
504
- save (child .state_dict (), save_path )
505
-
504
+ state_dict .update ({f'{ name } .{ k } ' : v for k ,v in child .state_dict ().items ()})
505
+ save_checkpoint (state_dict , save_path )
506
+
506
507
def resume_from_checkpoint (self , save_dir :str = None , model_path :str = None , optimizer_path :str = None , scheduler_path :str = None ,
507
508
steps_params_path :str = None , mapping :dict = None , verbose :int = 0 , strict :bool = True , ** kwargs ):
508
509
'''同时加载模型、优化器、训练过程参数
0 commit comments