17 changes: 17 additions & 0 deletions configs/hydra/default.yaml
@@ -0,0 +1,17 @@
# https://hydra.cc/docs/configure_hydra/intro/

# enable color logging by setting these overrides to 'colorlog'; if set to 'none', logging will not
# be modified by hydra (i.e. the logging config from the code will be used instead)
defaults:
- override hydra_logging: none
- override job_logging: none


# output directory, generated dynamically on each run
run:
dir: ./outputs/${project_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}_${task_name}_${run_name}

# if you want to disable automatic output directory creation, set run.dir to "."
# run:
# dir: .
# output_subdir: null # subdirectory of run.dir where hydra saves its config; defaults to .hydra, null disables it
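
Since both hydra_logging and job_logging are overridden to 'none' above, log formatting is left to the application itself (the configure_root_logger helper imported in src/main.py further down; its implementation is not part of this diff). A minimal sketch of what such a helper might do, purely as an assumption:

import logging
import sys


def configure_root_logger(level: int = logging.INFO) -> None:
    """Hypothetical sketch: attach a single stdout handler to the root logger."""
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    root = logging.getLogger()
    root.setLevel(level)
    root.handlers = [handler]  # replace any handlers installed earlier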
24 changes: 24 additions & 0 deletions configs/hydra/job_logging/custom.yaml
@@ -0,0 +1,24 @@
version: 1
formatters:
simple:
format: '%(asctime)s - %(levelname)s - %(message)s'
colorlog:
class: colorlog.ColoredFormatter
format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
datefmt: '%Y-%m-%d %H:%M:%S'
log_colors:
DEBUG: 'cyan'
INFO: 'green'
WARNING: 'yellow'
ERROR: 'red'
CRITICAL: 'bold_red'
handlers:
console:
class: logging.StreamHandler
formatter: colorlog
stream: ext://sys.stdout
level: INFO
root:
handlers: [console]

disable_existing_loggers: false
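
This file follows the standard logging.config.dictConfig schema. It is not active by default (configs/hydra/default.yaml overrides job_logging to 'none'), but if selected (e.g. with `override job_logging: custom`), Hydra applies it roughly as sketched below; the snippet only illustrates the mechanism and is not code from this PR:

import logging
import logging.config

from omegaconf import OmegaConf

# Load the YAML above and hand it to the standard logging machinery, which is
# essentially what Hydra does with the selected job_logging config. The
# colorlog.ColoredFormatter class comes from the colorlog/hydra-colorlog
# packages added to requirements.txt in this PR.
log_cfg = OmegaConf.to_container(
    OmegaConf.load("configs/hydra/job_logging/custom.yaml"), resolve=True
)
logging.config.dictConfig(log_cfg)
logging.getLogger(__name__).warning("custom logging config applied")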
12 changes: 12 additions & 0 deletions configs/main.yaml
@@ -0,0 +1,12 @@
defaults:
- _self_
- hydra: default
- model: vqvae
- data: cldhits
- trainer: ddp
- ml_logger: wandb # set logger here or override on the command line (e.g. `python -m src.main ml_logger=tensorboard`)
- paths: default

project_name: dev
run_name: main
task_name: train
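
Every entry in the defaults list above is a config group that can be swapped on the command line, e.g. `python -m src.main trainer=cpu ml_logger=wandb run_name=debug`. To inspect the composed config without launching a run, Hydra's compose API can be used (a sketch, not part of this PR; config_path is relative to the calling file, here assumed to live under src/):

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.3", config_path="../configs"):
    # Compose configs/main.yaml with a couple of overrides, mirroring the
    # command line above. Interpolations such as ${hydra:run.dir} are left
    # unresolved here.
    cfg = compose(config_name="main.yaml", overrides=["trainer=cpu", "run_name=debug"])

print(OmegaConf.to_yaml(cfg))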
14 changes: 14 additions & 0 deletions configs/ml_logger/wandb.yaml
@@ -0,0 +1,14 @@
wandb:
# _target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "deep-learning"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
# entity: "" # set to name of your wandb team
group: ""
tags: []
job_type: ""
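
Note that _target_ is commented out, so this group is presumably consumed by building the logger manually in the training code (not shown in this diff). Under that assumption, the wiring would look roughly like:

from lightning.pytorch.loggers import WandbLogger


def build_logger(cfg):
    # The keys above map onto WandbLogger arguments; extras such as group,
    # tags and job_type are forwarded to wandb.init via **kwargs.
    return WandbLogger(**cfg.ml_logger.wandb)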
50 changes: 50 additions & 0 deletions configs/model/vqvae.yaml
@@ -0,0 +1,50 @@
# _target_: src.models.vqvae.VQVAELightning

model_name: VQVAELightning

model_type: "VQVAENormFormer"

model_kwargs:
input_dim: 3
hidden_dim: 128
latent_dim: 16
num_blocks: 3
num_heads: 8
alpha: 5
vq_kwargs:
num_codes: 2048
beta: 0.9
kmeans_init: true
norm: null
cb_norm: null
affine_lr: 0.0
sync_nu: 2
replace_freq: 20
dim: -1

optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 0.001
# weight_decay: 0.05

optimizer_kwargs:
  lr: 0.001
  weight_decay: 0.001
  amsgrad: false

scheduler:
_target_: torch.optim.lr_scheduler.ConstantLR
_partial_: true

# using the schedule from the paper https://arxiv.org/abs/1902.08570, but with different parameters
# scheduler:
# _target_: src.schedulers.lr_scheduler.OneCycleCooldown
# _partial_: true
# warmup: 4
# cooldown: 10
# cooldown_final: 10
# max_lr: 0.0002
# initial_lr: 0.00003
# final_lr: 0.00002
# max_iters: 200
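
The optimizer and scheduler entries use Hydra's `_partial_` mechanism. Note that VQVAELightning's configure_optimizers (see src/models/vqvae.py below) builds its AdamW directly from optimizer_kwargs, so the following is only a sketch of how the `_partial_` blocks could be consumed:

import hydra
import torch


def build_optimization(cfg, model: torch.nn.Module):
    # With `_partial_: true`, instantiate() returns a functools.partial that
    # still needs the model parameters (and, for the scheduler, the optimizer).
    optimizer = hydra.utils.instantiate(cfg.model.optimizer)(model.parameters())
    scheduler = hydra.utils.instantiate(cfg.model.scheduler)(optimizer)
    return optimizer, scheduler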
18 changes: 18 additions & 0 deletions configs/paths/default.yaml
@@ -0,0 +1,18 @@
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
# root_dir: ${oc.env:PROJECT_ROOT}

# path to data directory
# data_dir: ${oc.env:DATA_DIR}

# path to logging directory
# log_dir: ${oc.env:LOG_DIR}

# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:run.dir}

# path to working directory
work_dir: ${hydra:runtime.cwd}
5 changes: 5 additions & 0 deletions configs/trainer/cpu.yaml
@@ -0,0 +1,5 @@
defaults:
- default.yaml

accelerator: cpu
devices: 1
16 changes: 16 additions & 0 deletions configs/trainer/ddp.yaml
@@ -0,0 +1,16 @@
# _target_: lightning.Trainer

defaults:
- default

accelerator: gpu
strategy: ddp
devices: 4

# mixed precision
precision: 16-mixed

# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
sync_batchnorm: True
18 changes: 18 additions & 0 deletions configs/trainer/default.yaml
@@ -0,0 +1,18 @@
# _target_: lightning.Trainer # Instantiating with hydra.utils.instantiate may pose a security risk

min_epochs: 1 # prevents early stopping
max_epochs: 10

accelerator: cpu
devices: 1
enable_progress_bar: False

# perform a validation loop every N training epochs
check_val_every_n_epoch: 1

# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False

# not needed for single-device or CPU training
sync_batchnorm: False
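
With _target_ deliberately commented out, the Trainer is presumably constructed directly from this config group somewhere in src/train.py (not shown in this diff). A minimal sketch under that assumption:

import lightning as L
from omegaconf import OmegaConf


def build_trainer(cfg):
    # Convert the trainer group to a plain dict and pass it straight to the
    # Trainer constructor instead of going through hydra.utils.instantiate.
    trainer_kwargs = OmegaConf.to_container(cfg.trainer, resolve=True)
    return L.Trainer(**trainer_kwargs)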
5 changes: 5 additions & 0 deletions requirements.txt
@@ -21,3 +21,8 @@ torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1
nbdev
lightning
tensorboardX
hydra-core
hydra-colorlog
omegaconf
Empty file modified scripts/flatiron/load_modules.sh
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion src/datasets/CLDHits.py
@@ -152,7 +152,7 @@ def __iter__(self):

else:

# new code to return per particle, not per event
# return one hit at a time instead of one event
for i in range(len(calo_hit_features)):
if self.nsamples is not None and self.sample_counter >= self.nsamples:
return
42 changes: 42 additions & 0 deletions src/main.py
@@ -0,0 +1,42 @@
import hydra
from omegaconf import DictConfig, OmegaConf

from src.utils.pylogger import configure_root_logger


@hydra.main(version_base="1.3", config_path="../configs", config_name="main.yaml")
def main_hydra_decorated(cfg: DictConfig):
configure_root_logger()
    # lazy import to speed up initial loading, e.g. when running `python -m src.main --help`
from src.train import train

train(cfg)


def main():
"""
Main function to run the training script with Hydra configuration.
    This function is a workaround that uses Hydra's configuration management
    without wrapping the main training code in the @hydra.main decorator.
It allows for command-line overrides and configuration loading.
"""
configure_root_logger()

cfg_holder = []

@hydra.main(version_base="1.3", config_path="../configs", config_name="main.yaml")
def parse(cfg):
OmegaConf.resolve(cfg) # If you need resolving, it needs to be done here
cfg_holder.append(cfg)

parse()
cfg = cfg_holder[0]

# Run main code
from src.train import train # lazy import to speed up initial loading

train(cfg)


if __name__ == "__main__":
    main_hydra_decorated()  # use main() instead if you don't want train() wrapped in @hydra.main
2 changes: 2 additions & 0 deletions src/models/__init__.py
@@ -0,0 +1,2 @@
from .vqvae import VQVAELightning # noqa: F401
from .self_attention_transformer import SelfAttentionTransformerLightning # noqa: F401
48 changes: 48 additions & 0 deletions src/models/self_attention_transformer.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from torch.nn.attention import SDPBackend, sdpa_kernel
import lightning.pytorch as pl


def get_activation(activation):
@@ -216,3 +217,50 @@ def forward(self, x_orig, mask=None):
x = self.norm(final_embedding)
x = self.fc_out(x)
return x


class SelfAttentionTransformerLightning(pl.LightningModule):
def __init__(self, loss_function, lr, model_kwargs=None):
super(SelfAttentionTransformerLightning, self).__init__()
        self.model = SelfAttentionTransformer(**(model_kwargs or {}))  # tolerate the model_kwargs=None default
self.learning_rate = lr
self.loss_function = loss_function

def forward(self, x, mask=None):
return self.model(x, mask)

def training_step(self, batch, batch_idx):
x = batch["calo_hit_features"]
y = batch["hit_labels"]

y_pred = self(x)
loss = self.loss_function(y_pred, y)

# Logging
self.log("train_loss", loss)
return loss

def validation_step(self, batch, batch_idx):
x = batch["calo_hit_features"]
y = batch["hit_labels"]

y_pred = self(x)
loss = self.loss_function(y_pred, y)

# Logging
self.log("val_loss", loss)

def test_step(self, batch, batch_idx):
x = batch["calo_hit_features"]
y = batch["hit_labels"]

y_pred = self(x)
loss = self.loss_function(y_pred, y)

# Logging
self.log("test_loss", loss)

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
return [optimizer] # , [scheduler]
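
For context, the new LightningModule is driven like any other; the model_kwargs below are placeholders and must match SelfAttentionTransformer's actual signature, which is defined earlier in this file:

import lightning.pytorch as pl
import torch.nn as nn

from src.models import SelfAttentionTransformerLightning

module = SelfAttentionTransformerLightning(
    loss_function=nn.CrossEntropyLoss(),
    lr=1e-3,
    # placeholder kwargs -- adjust to the real SelfAttentionTransformer signature
    model_kwargs={"input_dim": 3, "hidden_dim": 128, "num_blocks": 3, "num_heads": 8},
)
trainer = pl.Trainer(max_epochs=10, accelerator="cpu", devices=1)
# trainer.fit(module, train_dataloaders=..., val_dataloaders=...)
# batches must provide "calo_hit_features" and "hit_labels", as the *_step methods expect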
13 changes: 7 additions & 6 deletions src/models/vqvae.py
@@ -3,7 +3,7 @@
import logging
import time
from pathlib import Path
from typing import Any, Dict, Tuple
from typing import Tuple

import lightning as L
import matplotlib.pyplot as plt
@@ -14,7 +14,12 @@
import vector
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from vqtorch.nn import VectorQuant

# vqtorch can be installed from https://github.com/minyoungg/vqtorch
try:
from vqtorch.nn import VectorQuant # type: ignore
except ImportError as e:
raise ImportError("vqtorch is not installed. Please install it to use this module.") from e

from src.utils.arrays import (
ak_pad,
@@ -342,8 +347,6 @@ def __init__(
super().__init__()
self.save_hyperparameters(logger=False)

# --------------- load pretrained model --------------- #
# if kwargs.get("load_pretrained", False):
if model_type == "MLP":
self.model = VQVAEMLP(**model_kwargs)
elif model_type == "Transformer":
@@ -371,7 +374,6 @@ def __init__(

def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.model.parameters(), **self.optimizer_kwargs)
"""
if self.lr_scheduler:
return {
"optimizer": optimizer,
@@ -381,7 +383,6 @@ def configure_optimizers(self):
"frequency": self.lr_scheduler_frequency,
},
}
"""
return optimizer

def forward(self, x_particle, mask_particle):