
Commit d07e1cf

debug save_pretrained
1 parent 85c7497 commit d07e1cf

5 files changed (+19, -13 lines)


README.md (+1, -1)

@@ -67,7 +67,7 @@ pip install git+https://github.com/Tongjilibo/torch4keras.git
 ## 4. Version history
 |Date| Version | Notes |
 |------| ----------------- |----------- |
-|20240116|v0.1.8 | Reorganized snippets|
+|20240116|v0.1.8 | Reorganized snippets, rewrote save_pretrained|
 |20231219|v0.1.7 | Added SimpleStreamFileLogger and LoggerHandler, changed the Logger format|
 |20231208|v0.1.6.post2 |Monitor the fit process and send an email alert on errors; resolved the torch 2.0 compile conflict; fixed the clip_grad_norm bug|
 |20230928|v0.1.5 |Show elapsed training time in the progress bar|

docs/History.md (+1, -1)

@@ -1,6 +1,6 @@
 ## Update history

-- **20240116**: Reorganized snippets
+- **20240116**: Reorganized snippets, rewrote save_pretrained
 - **20231219**: Added SimpleStreamFileLogger and LoggerHandler, changed the Logger format
 - **20231208**: Monitor the fit process and send an email alert on errors; resolved the torch 2.0 compile conflict
 - **20230928**: Show elapsed training time in the progress bar

examples/tutorials_mnist.py (+1, -1)

@@ -76,7 +76,7 @@ def evaluate(self):
     email = EmailCallback(mail_receivers='[email protected]')  # send an email
     wandb = WandbCallback(save_code=True)  # wandb
     hist = model.fit(train_dataloader, steps_per_epoch=steps_per_epoch, epochs=epochs,
-                     callbacks=[Summary(), evaluator, logger, ckpt, early_stop])
+                     callbacks=[Summary(), evaluator, ts_board, logger, ckpt, early_stop])
 else:
     model.load_weights('./ckpt/5/model.pt')
     metrics = MyEvaluator().evaluate()

torch4keras/snippets/data_process.py (+7, -2)

@@ -5,6 +5,8 @@
 from torch.utils.data import Dataset, IterableDataset
 import inspect
 from .import_utils import is_safetensors_available, is_sklearn_available
+import os
+

 if is_safetensors_available():
     from safetensors import safe_open
@@ -196,7 +198,7 @@ def metric_mapping(metric, func, y_pred, y_true):
     return None


-def load(checkpoint:str, load_safetensors:bool=False):
+def load_checkpoint(checkpoint:str, load_safetensors:bool=False):
     '''Load a checkpoint; supports torch.load and safetensors
     '''
     if load_safetensors or checkpoint.endswith(".safetensors"):
@@ -218,9 +220,12 @@ def load(checkpoint:str, load_safetensors:bool=False):
     return torch.load(checkpoint, map_location='cpu')


-def save(state_dict:dict, save_path:str, save_safetensors:bool=False):
+def save_checkpoint(state_dict:dict, save_path:str, save_safetensors:bool=False):
     '''Save a checkpoint; supports torch.save and safetensors
     '''
+    save_dir = os.path.dirname(save_path)
+    os.makedirs(save_dir, exist_ok=True)
+
     if save_safetensors or save_path.endswith('.safetensors'):
         safe_save_file(state_dict, save_path, metadata={"format": "pt"})
     else:
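
For reference, a minimal usage sketch of the renamed helpers. The function names, signatures, and import path are taken from this diff; the toy module and paths are illustrative only, and the safetensors branch assumes the safetensors package is installed:

```python
import torch
from torch4keras.snippets import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 2)  # toy stand-in for any nn.Module

# save_checkpoint now creates the parent directory itself and writes
# safetensors when the flag is set or the path ends in .safetensors
save_checkpoint(model.state_dict(), './ckpt/demo/model.safetensors', save_safetensors=True)

# load_checkpoint dispatches on the extension (or the explicit flag),
# otherwise it falls back to torch.load(..., map_location='cpu')
state_dict = load_checkpoint('./ckpt/demo/model.safetensors')
model.load_state_dict(state_dict)
```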

torch4keras/trainer.py (+9, -8)

@@ -1,7 +1,7 @@
 from torch import nn
 import torch
 from torch4keras.snippets import DottableDict, metric_mapping, get_parameter_device, log_info, log_warn, log_error
-from torch4keras.snippets import print_trainable_parameters, colorful, monitor_run_by_email, load, save
+from torch4keras.snippets import print_trainable_parameters, colorful, monitor_run_by_email, load_checkpoint, save_checkpoint
 from torch4keras.callbacks import KerasProgbar, SmoothMetricsCallback, TqdmProgbar, ProgressBar2Progbar, Callback, CallbackList, History
 from collections import OrderedDict
 from typing import Union
@@ -460,7 +460,7 @@ def load_weights(self, load_path:Union[str,tuple,list], strict:bool=True, mappin

         mapping = mapping or dict()
         for load_path_i in load_path:
-            state_dict = load(load_path_i)
+            state_dict = load_checkpoint(load_path_i)
             for k in list(state_dict.keys()):
                 if k in mapping:
                     state_dict[mapping[k]] = state_dict.pop(k)
@@ -486,9 +486,7 @@ def save_weights(self, save_path:str, mapping:dict=None, trainable_only:bool=Fal
             if k in mapping:
                 state_dict[mapping[k]] = state_dict.pop(k)

-        save_dir = os.path.dirname(save_path)
-        os.makedirs(save_dir, exist_ok=True)
-        save(state_dict, save_path)
+        save_checkpoint(state_dict, save_path)
         if trainable_only:
             params_all = sum(p.numel() for p in self.unwrap_model().parameters())
             params_trainable = sum(p.numel() for p in self.unwrap_model().parameters() if p.requires_grad)
@@ -497,12 +495,15 @@ def save_weights(self, save_path:str, mapping:dict=None, trainable_only:bool=Fal

     def save_pretrained(self, save_path:str, weight_map:dict=None, mapping:dict=None):
         '''Save the model using the pretrained model's keys, so it can be loaded by the transformers package'''
+        state_dict = dict()
         for name, child in self.unwrap_model().named_children():
             if (name != '') and hasattr(child, 'save_pretrained'):
-                child.save_pretrained(save_path, weight_map, mapping)
+                tmp = child.save_pretrained(save_path, weight_map, mapping, write_to_disk=False)
+                state_dict.update(tmp if tmp else {})
             else:
-                save(child.state_dict(), save_path)
-
+                state_dict.update({f'{name}.{k}': v for k,v in child.state_dict().items()})
+        save_checkpoint(state_dict, save_path)
+
     def resume_from_checkpoint(self, save_dir:str=None, model_path:str=None, optimizer_path:str=None, scheduler_path:str=None,
                                steps_params_path:str=None, mapping:dict=None, verbose:int=0, strict:bool=True, **kwargs):
         '''Load the model, optimizer, and training-state parameters together
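
The previous save_pretrained called save(child.state_dict(), save_path) inside the loop, so each child overwrote the file written by the one before; the rewrite accumulates everything into one state_dict and writes it once at the end. Below is a standalone sketch of that aggregation pattern for children without their own save_pretrained; the toy model and path are illustrative only:

```python
from torch import nn
from torch4keras.snippets import save_checkpoint  # helper renamed in this commit

# toy stand-in for self.unwrap_model(); real submodules that implement
# save_pretrained contribute their own weight maps instead
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

state_dict = dict()
for name, child in model.named_children():
    # prefix each child's parameter names with its attribute name,
    # mirroring the else-branch of the rewritten save_pretrained
    state_dict.update({f'{name}.{k}': v for k, v in child.state_dict().items()})

# a single write at the end instead of one overwrite per child
save_checkpoint(state_dict, './ckpt/demo/pytorch_model.bin')
```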
