Skip to content

Commit abba9e1

Browse files
committed
v0.1.8
1 parent d07e1cf commit abba9e1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch4keras/snippets/data_process.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ def save_checkpoint(state_dict:dict, save_path:str, save_safetensors:bool=False)
224224
'''保存ckpt,支持torch.save和safetensors
225225
'''
226226
save_dir = os.path.dirname(save_path)
227-
os.makedirs(save_dir, exist_ok=True)
227+
if save_dir:
228+
os.makedirs(save_dir, exist_ok=True)
228229

229230
if save_safetensors or save_path.endswith('.safetensors'):
230231
safe_save_file(state_dict, save_path, metadata={"format": "pt"})

0 commit comments

Comments
 (0)