|
3 | 3 | import urllib.request as urlreq |
4 | 4 |
|
5 | 5 | import gdown |
| 6 | +import os |
| 7 | +from pathlib import Path |
6 | 8 |
|
7 | 9 | from theseus.base.utilities.loggers.observer import LoggerObserver |
8 | 10 |
|
@@ -60,18 +62,28 @@ def download_from_url(url, root=None, filename=None): |
60 | 62 | return fpath |
61 | 63 |
|
62 | 64 |
|
63 | | -def download_from_wandb(filename, run_path, save_dir, generate_id_text_file=False): |
64 | | - import wandb |
| 65 | +def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_text_file=False): |
65 | 66 |
|
| 67 | + import wandb |
| 68 | + |
66 | 69 | try: |
67 | 70 | path = wandb.restore(filename, run_path=run_path, root=save_dir) |
68 | | - |
| 71 | + LOGGER.text("Successfully download {} from wandb run path {}".format(filename, run_path), level=LoggerObserver.INFO) |
| 72 | + |
69 | 73 | # Save run id to wandb_id.txt |
70 | 74 | if generate_id_text_file: |
71 | 75 | wandb_id = osp.basename(run_path) |
72 | 76 | with open(osp.join(save_dir, "wandb_id.txt"), "w") as f: |
73 | 77 | f.write(wandb_id) |
74 | | - |
| 78 | + |
| 79 | + if rename: |
| 80 | + new_name = str(Path(path.name).resolve().parent / rename) |
| 81 | + os.rename(Path(path.name).resolve(), new_name) |
| 82 | + LOGGER.text("Saved to {}".format(new_name), level=LoggerObserver.INFO) |
| 83 | + return new_name |
| 84 | + |
| 85 | + |
| 86 | + LOGGER.text("Saved to {}".format((Path(save_dir) / path.name).resolve()), level=LoggerObserver.INFO) |
75 | 87 | return path.name |
76 | 88 | except: |
77 | 89 | LOGGER.text("Failed to download from wandb.", level=LoggerObserver.ERROR) |
|
0 commit comments