
Commit 835bf4a

Use PyTorch lightning for training and Hydra + OmegaConf for configuration (#6)
* update requirements.txt
* make load_modules.sh executable
* enable lr schedules in VQVAE
* use lightning and hydra+omegaconf for training script
1 parent 1718fcf commit 835bf4a

File tree

18 files changed (+445, -174 lines)

configs/hydra/default.yaml

Lines changed: 17 additions & 0 deletions
# https://hydra.cc/docs/configure_hydra/intro/

# enable color logging by setting to 'colorlog' -- if set to 'none', logging will not
# be modified by hydra (i.e. then the logging config from the code will be used)
defaults:
  - override hydra_logging: none
  - override job_logging: none

# output directory, generated dynamically on each run
run:
  dir: ./outputs/${project_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}_${task_name}_${run_name}

# if you want to disable automatic output directory creation, set run.dir to "."
# run:
#   dir: .
# output_subdir: null # if set, will be appended to run.dir. Default is .hydra

Lines changed: 24 additions & 0 deletions
version: 1
formatters:
  simple:
    format: '%(asctime)s - %(levelname)s - %(message)s'
  colorlog:
    class: colorlog.ColoredFormatter
    format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
    log_colors:
      DEBUG: 'cyan'
      INFO: 'green'
      WARNING: 'yellow'
      ERROR: 'red'
      CRITICAL: 'bold_red'
handlers:
  console:
    class: logging.StreamHandler
    formatter: colorlog
    stream: ext://sys.stdout
    level: INFO
root:
  handlers: [console]

disable_existing_loggers: false
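
This block is a standard Python logging dictConfig that Hydra's hydra_logging/job_logging overrides can point at (its file path is not shown in this view). A minimal sketch of applying it by hand, assuming the file lives at configs/hydra/job_logging/colorlog.yaml and that the colorlog package pulled in via hydra-colorlog is installed:

```python
import logging
import logging.config

from omegaconf import OmegaConf

# assumed path -- the real location of this file is not visible in the diff above
log_cfg = OmegaConf.load("configs/hydra/job_logging/colorlog.yaml")
logging.config.dictConfig(OmegaConf.to_container(log_cfg, resolve=True))
logging.getLogger(__name__).info("colored console logging configured")
```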

configs/main.yaml

Lines changed: 12 additions & 0 deletions
defaults:
  - _self_
  - hydra: default
  - model: vqvae
  - data: cldhits
  - trainer: ddp
  - ml_logger: wandb # set logger here or use command line (e.g. `python train.py ml_logger=tensorboard`)
  - paths: default

project_name: dev
run_name: main
task_name: train
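
For orientation, here is a minimal sketch of the kind of Hydra entry point that consumes this primary config; the actual train.py in this commit may be wired differently, and the VQVAELightning import path is only inferred from the commented _target_ hints in the configs:

```python
import hydra
import lightning as L
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="main")
def main(cfg: DictConfig) -> None:
    # cfg is the composed config: hydra, model, data, trainer, ml_logger, paths
    print(OmegaConf.to_yaml(cfg))

    # assumed import path, taken from the commented _target_ in configs/model/vqvae.yaml
    from src.models.vqvae import VQVAELightning

    model = VQVAELightning(**cfg.model.model_kwargs)  # assumed constructor signature
    trainer = L.Trainer(**OmegaConf.to_container(cfg.trainer, resolve=True))
    trainer.fit(model)  # a datamodule built from cfg.data would normally be passed as well


if __name__ == "__main__":
    main()
```

Group selections can then be swapped from the command line, e.g. `python train.py trainer=cpu ml_logger=tensorboard`.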

configs/ml_logger/wandb.yaml

Lines changed: 14 additions & 0 deletions
wandb:
  # _target_: lightning.pytorch.loggers.wandb.WandbLogger
  # name: "" # name of the run (normally generated by wandb)
  save_dir: "${paths.output_dir}"
  offline: False
  id: null # pass correct id to resume experiment!
  anonymous: null # enable anonymous logging
  project: "deep-learning"
  log_model: False # upload lightning ckpts
  prefix: "" # a string to put at the beginning of metric keys
  # entity: "" # set to name of your wandb team
  group: ""
  tags: []
  job_type: ""
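
The commented _target_ above hints that this block maps onto Lightning's WandbLogger; a hedged sketch of turning it into a logger object inside a Hydra run (the repo's actual training script may do this differently):

```python
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf


def build_wandb_logger(cfg: DictConfig) -> WandbLogger:
    # resolve=True expands ${paths.output_dir}; this needs the composed config of a
    # Hydra run, since that interpolation ultimately points at ${hydra:run.dir}
    kwargs = OmegaConf.to_container(cfg.ml_logger.wandb, resolve=True)
    return WandbLogger(**kwargs)
```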

configs/model/vqvae.yaml

Lines changed: 50 additions & 0 deletions
# _target_: src.models.vqvae.VQVAELightning

model_name: VQVAELightning

model_type: "VQVAENormFormer"

model_kwargs:
  input_dim: 3
  hidden_dim: 128
  latent_dim: 16
  num_blocks: 3
  num_heads: 8
  alpha: 5
  vq_kwargs:
    num_codes: 2048
    beta: 0.9
    kmeans_init: true
    norm: null
    cb_norm: null
    affine_lr: 0.0
    sync_nu: 2
    replace_freq: 20
    dim: -1

optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 0.001
  # weight_decay: 0.05

optimizer_kwargs:
  lr: 0.001
  weight_decay: 0.001
  amsgrad: False

scheduler:
  _target_: torch.optim.lr_scheduler.ConstantLR
  _partial_: true

# using the method listed in the paper https://arxiv.org/abs/1902.08570, but with other parameters
# scheduler:
#   _target_: src.schedulers.lr_scheduler.OneCycleCooldown
#   _partial_: true
#   warmup: 4
#   cooldown: 10
#   cooldown_final: 10
#   max_lr: 0.0002
#   initial_lr: 0.00003
#   final_lr: 0.00002
#   max_iters: 200
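
The _partial_: true flags mean hydra.utils.instantiate returns factories (functools.partial objects) instead of constructed optimizers, so the model parameters can be bound later, typically inside configure_optimizers. A small sketch of that pattern (assumed usage, not code from this commit):

```python
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/vqvae.yaml")

# with _partial_: true these are functools.partial(AdamW, lr=0.001) and partial(ConstantLR)
optimizer_factory = hydra.utils.instantiate(cfg.optimizer)
scheduler_factory = hydra.utils.instantiate(cfg.scheduler)

# inside a LightningModule.configure_optimizers this would typically become:
#   optimizer = optimizer_factory(params=self.parameters())
#   scheduler = scheduler_factory(optimizer=optimizer)
#   return {"optimizer": optimizer, "lr_scheduler": scheduler}
```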

configs/paths/default.yaml

Lines changed: 18 additions & 0 deletions
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
# root_dir: ${oc.env:PROJECT_ROOT}

# path to data directory
# data_dir: ${oc.env:DATA_DIR}

# path to logging directory
# log_dir: ${oc.env:LOG_DIR}

# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:run.dir}

# path to working directory
work_dir: ${hydra:runtime.cwd}

configs/trainer/cpu.yaml

Lines changed: 5 additions & 0 deletions
defaults:
  - default.yaml

accelerator: cpu
devices: 1

configs/trainer/ddp.yaml

Lines changed: 16 additions & 0 deletions
# _target_: lightning.Trainer

defaults:
  - default

accelerator: gpu
strategy: ddp
devices: 4

# mixed precision
precision: 16-mixed

# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
sync_batchnorm: True

configs/trainer/default.yaml

Lines changed: 18 additions & 0 deletions
# _target_: lightning.Trainer # Instantiating with hydra.utils.instantiate may pose a security risk

min_epochs: 1 # prevents early stopping
max_epochs: 10

accelerator: cpu
devices: 1
enable_progress_bar: False

# perform a validation loop every N training epochs
check_val_every_n_epoch: 1

# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False

# not needed for single-device or CPU training
sync_batchnorm: False
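
Because the _target_ line is deliberately commented out (instantiating the Trainer through hydra.utils.instantiate is flagged as a potential security risk), the Trainer is presumably built from plain keyword arguments; a sketch under that assumption:

```python
import lightning as L
from omegaconf import OmegaConf

# loading the single file directly works here because default.yaml has no defaults list
# or interpolations; in train.py the composed cfg.trainer would be used instead
cfg = OmegaConf.load("configs/trainer/default.yaml")
trainer = L.Trainer(**OmegaConf.to_container(cfg, resolve=True))
print(trainer.max_epochs)  # 10
```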

requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -21,3 +21,8 @@ torch==2.5.1
 torchvision==0.20.1
 torchaudio==2.5.1
 nbdev
+lightning
+tensorboardX
+hydra-core
+hydra-colorlog
+omegaconf
