Skip to content

Commit 3b86304

Browse files
committed
屏蔽torch.load警告, Timeit提示
1 parent d9d5b72 commit 3b86304

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

torch4keras/snippets/data_process.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,10 @@ def load_checkpoint(checkpoint:str, load_safetensors:bool=False):
229229
return safe_load_file(checkpoint)
230230
else:
231231
# 正常加载pytorch_model.bin
232-
return torch.load(checkpoint, map_location='cpu')
232+
if 'weights_only' in inspect.signature(torch.load).parameters:
233+
return torch.load(checkpoint, map_location='cpu', weights_only=True)
234+
else:
235+
return torch.load(checkpoint, map_location='cpu')
233236

234237

235238
def save_checkpoint(state_dict:dict, save_path:str, save_safetensors:bool=False):

torch4keras/snippets/monitor.py

+12
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
152152
self.lap(name='Total', reset=False)
153153

154154

155+
class Timeit(TimeitContextManager):
156+
def __init__(self) -> None:
157+
super().__init__()
158+
raise DeprecationWarning('`Timeit` has been deprecated since torch4keras==v0.2.8, use `TimeitContextManager` instead')
159+
160+
155161
class TimeitLogger:
156162
'''记录耗时
157163
@@ -209,6 +215,12 @@ def end(self, verbose=1):
209215
return cost
210216

211217

218+
class Timeit2(TimeitLogger):
219+
def __init__(self) -> None:
220+
super().__init__()
221+
raise DeprecationWarning('`Timeit2` has been deprecated since torch4keras==v0.2.8, use `TimeitLogger` instead')
222+
223+
212224
def send_email(mail_receivers:Union[str,list], mail_subject:str, mail_msg:str="", mail_host:str=None,
213225
mail_user:str=None, mail_pwd:str=None, mail_sender:str=None):
214226
''' 发送邮件(默认使用笔者自己注册的邮箱, 若含敏感信息请使用自己注册的邮箱)

0 commit comments

Comments
 (0)