Skip to content

Commit c4623d9

Browse files
authored
Merge pull request #46 from LouisDo2108/master
add a rename param for download_from_wandb
2 parents b7298bc + 75c9110 commit c4623d9

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

theseus/base/utilities/download.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import urllib.request as urlreq
44

55
import gdown
6+
import os
7+
from pathlib import Path
68

79
from theseus.base.utilities.loggers.observer import LoggerObserver
810

@@ -60,18 +62,28 @@ def download_from_url(url, root=None, filename=None):
6062
return fpath
6163

6264

63-
def download_from_wandb(filename, run_path, save_dir, generate_id_text_file=False):
64-
import wandb
65+
def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_text_file=False):
6566

67+
import wandb
68+
6669
try:
6770
path = wandb.restore(filename, run_path=run_path, root=save_dir)
68-
71+
LOGGER.text("Successfully download {} from wandb run path {}".format(filename, run_path), level=LoggerObserver.INFO)
72+
6973
# Save run id to wandb_id.txt
7074
if generate_id_text_file:
7175
wandb_id = osp.basename(run_path)
7276
with open(osp.join(save_dir, "wandb_id.txt"), "w") as f:
7377
f.write(wandb_id)
74-
78+
79+
if rename:
80+
new_name = str(Path(path.name).resolve().parent / rename)
81+
os.rename(Path(path.name).resolve(), new_name)
82+
LOGGER.text("Saved to {}".format(new_name), level=LoggerObserver.INFO)
83+
return new_name
84+
85+
86+
LOGGER.text("Saved to {}".format((Path(save_dir) / path.name).resolve()), level=LoggerObserver.INFO)
7587
return path.name
7688
except:
7789
LOGGER.text("Failed to download from wandb.", level=LoggerObserver.ERROR)

0 commit comments

Comments
 (0)