-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtrain_auto_explore.py
More file actions
66 lines (49 loc) · 1.76 KB
/
train_auto_explore.py
File metadata and controls
66 lines (49 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
from pathlib import Path
import hydra
from omegaconf import DictConfig, OmegaConf
from tools.model_management import CheckpointDirManager
from auto_explore.src.utils import dump_hydra, FileStructure
from auto_explore.src.trainer import Trainer
if "SLURM_NTASKS" in os.environ:
    # Remove SLURM env variables to avoid issues with Lightning.
    del os.environ["SLURM_NTASKS"]
    # SLURM_JOB_NAME is not guaranteed to be set just because SLURM_NTASKS
    # is; pop with a default so a missing variable does not raise KeyError
    # at import time.
    os.environ.pop("SLURM_JOB_NAME", None)
from lightning.fabric import Fabric
from tools.logger import getLogger
# Module-level logger obtained from the project's `tools.logger` helper,
# keyed by this module's name.
log = getLogger(__name__)
def run(cfg: DictConfig):
    """Launch a Fabric DDP run, select the experiment directory and train.

    Args:
        cfg: Hydra/OmegaConf configuration. Reads ``common.device``,
            ``common.root_dpath``, ``common.resume``, ``common.name``,
            ``world_model.root_dpath`` and ``world_model.model_dname``.
            Overwrites ``common.root_dpath`` and ``world_model.root_dpath``
            with their absolute forms and sets ``wandb.name`` to the
            experiment directory name.

    Side effects:
        Creates the experiment directory if needed, changes the process
        working directory to it, dumps the config with ``dump_hydra`` and
        runs ``Trainer``.
    """
    fabric = Fabric(
        strategy="ddp",
        accelerator=cfg.common.device,
        devices="auto",
        precision="bf16-mixed",
    )
    fabric.launch()
    fabric.barrier()

    # Store absolute paths in the config (the working directory is changed
    # with os.chdir further down).
    cfg.common.root_dpath = os.path.abspath(cfg.common.root_dpath)
    cfg.world_model.root_dpath = os.path.abspath(cfg.world_model.root_dpath)
    root_dpath = cfg.common.root_dpath
    cdm = CheckpointDirManager(root_dpath)

    if cfg.common.resume:
        dpath = cdm.get_last_dpath()
    else:
        # Only global rank zero builds the next directory; the remaining
        # ranks wait at the barrier, then refresh the manager and pick up
        # the last directory.
        # NOTE(review): nesting of this barrier inside the non-resume branch
        # is assumed — confirm against the repository version.
        if fabric.is_global_zero:
            dpath = cdm.build_dpath_next(cfg.common.name)
        fabric.barrier()
        if not fabric.is_global_zero:
            cdm.update()
            dpath = cdm.get_last_dpath()

    log.i(f"Using wm {cfg.world_model.model_dname}")

    # Normalise to Path *before* reading `.name`, so a plain string path
    # returned by the directory manager works as well.
    dpath = Path(dpath)
    fname = dpath.name
    cfg.wandb.name = fname

    os.makedirs(dpath, exist_ok=True)
    os.chdir(dpath)
    log.i(f"Running experiment: {fname}")

    fs = FileStructure(dpath)
    log.i("Data path:", dpath)
    fs.create()
    dump_hydra(cfg, fs.hydra_config_fpath)

    trainer = Trainer(cfg, fabric, fs)
    trainer.run()
@hydra.main(config_path="auto_explore/configs", config_name="trainer")
def main(cfg: DictConfig):
    """Hydra entry point.

    Receives the config that Hydra builds from ``auto_explore/configs``
    (config name ``trainer``) and forwards it to ``run``.
    """
    run(cfg=cfg)


if __name__ == "__main__":
    # Executed as a script: let the Hydra-decorated entry point take over.
    main()