Skip to content

Commit 9bf7ea4

Browse files
committed
Init
0 parents  commit 9bf7ea4

26 files changed

Lines changed: 1063 additions & 0 deletions

.gitignore

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
volume/*
2+
outputs/*
3+
Cache/*
4+
runs/*
5+
6+
*.nii.gz
7+
*.raw
8+
*.mhd
9+
*.stl
10+
*cache*
11+

Readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# CVSSL - Contrastive volumetric self supervised learning

callbacks/metric_evaluator.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) 2025 lightning-hydra-boilerplate
2+
# Licensed under the MIT License.
3+
4+
"""Callback for logging metrics during validation and test using PyTorch Lightning."""
5+
6+
import torch
7+
from lightning.pytorch import Callback, LightningModule, Trainer
8+
from lightning.pytorch.utilities.types import STEP_OUTPUT
9+
10+
11+
class MetricEvaluator(Callback):
12+
"""Logs custom metrics during validation and test epochs.
13+
14+
Args:
15+
metrics (Dict[str, List[torch.nn.Module]]): A dictionary mapping each stage
16+
("validation", "test") to a list of metric instances.
17+
18+
Example:
19+
{
20+
"validation": [Accuracy(...), F1Score(...)],
21+
"test": [Accuracy(...), F1Score(...)]
22+
}
23+
"""
24+
25+
def __init__(self, metrics: dict[str, list[torch.nn.Module]]) -> None:
26+
self.metrics = metrics
27+
28+
def _reset(self, stage: str) -> None:
29+
"""Reset all metrics for a given stage."""
30+
for metric in self.metrics.get(stage, []):
31+
metric.reset()
32+
33+
def _update(self, stage: str, preds: torch.Tensor, targets: torch.Tensor) -> None:
34+
"""Update all metrics for a given stage with predictions and targets."""
35+
for metric in self.metrics.get(stage, []):
36+
metric.to(preds.device)
37+
metric.update(preds, targets)
38+
39+
def _log(self, stage: str, pl_module: LightningModule) -> None:
40+
"""Compute and log all metrics for a given stage."""
41+
# prefix = "val" if stage == "validation" else "test"
42+
# for metric in self.metrics.get(stage, []):
43+
# name = metric.__class__.__name__.lower()
44+
# value = metric.compute()
45+
# pl_module.log(f"{prefix}_{name}", value, prog_bar=True)
46+
47+
def on_validation_epoch_start(self, _trainer: Trainer, _pl_module: LightningModule) -> None:
48+
"""Reset metrics at the start of validation epoch."""
49+
self._reset("validation")
50+
51+
def on_validation_batch_end(
52+
self,
53+
_trainer: Trainer,
54+
pl_module: LightningModule,
55+
_outputs: STEP_OUTPUT,
56+
batch: dict,
57+
_batch_idx: int,
58+
_dataloader_idx: int = 0,
59+
) -> None:
60+
"""Update validation metrics with each batch."""
61+
x, y = batch[0], batch[1]
62+
#preds = pl_module(x).argmax(dim=1)
63+
#self._update("validation", preds, y)
64+
65+
def on_validation_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None:
66+
"""Compute and log validation metrics at end of epoch."""
67+
self._log("validation", pl_module)
68+
69+
def on_test_epoch_start(self, _trainer: Trainer, _pl_module: LightningModule) -> None:
70+
"""Reset metrics at the start of test epoch."""
71+
self._reset("test")
72+
73+
def on_test_batch_end(
74+
self,
75+
_trainer: Trainer,
76+
pl_module: LightningModule,
77+
_outputs: STEP_OUTPUT,
78+
batch: dict,
79+
_batch_idx: int,
80+
_dataloader_idx: int = 0,
81+
) -> None:
82+
"""Update test metrics with each batch."""
83+
# x, y = batch["image"], batch["label"]
84+
# preds = pl_module(x).argmax(dim=1)
85+
# self._update("test", preds, y)
86+
87+
def on_test_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None:
88+
"""Compute and log test metrics at end of epoch."""
89+
self._log("test", pl_module)

configs/data/example_data.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_target_: data.example_data.lightning_datamodule.ExampleDataModule
2+
batch_size: 32
3+
num_workers: 4
4+
path: "/mnt/DATA/**/*.nii.gz"

configs/eval.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
defaults:
2+
- _self_
3+
- model: ???
4+
- data: ???
5+
- trainer: default
6+
- experiment: ???
7+
- hydra: default
8+
9+
mode: eval
10+
ckpt_path: ??? # User must override this
11+
data_split: "test" # one of: test or val
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# @package _global_
2+
3+
defaults:
4+
- override /model: example_model # Override the model to use
5+
- override /data: example_data # Override the data module to use
6+
- override /trainer: base # Override or define trainer settings
7+
8+
# Base configuration
9+
seed: 42
10+
skip_test: False
11+
experiment_name: "exp"
12+
13+
# Model parameters
14+
model:
15+
optimizer:
16+
lr: 0.0005
17+
18+
# Data parameters
19+
data:
20+
batch_size: 18
21+
22+
# Lightning Trainer parameters
23+
trainer:
24+
max_epochs: 100
25+
precision: 32
26+
check_val_every_n_epoch: 4
27+
callbacks:
28+
metric_evaluator:
29+
_target_: callbacks.metric_evaluator.MetricEvaluator
30+
metrics:
31+
validation:
32+
- _target_: torchmetrics.Accuracy
33+
task: "multiclass"
34+
num_classes: ${model.num_classes}
35+
test:
36+
- _target_: torchmetrics.Accuracy
37+
task: "multiclass"
38+
num_classes: ${model.num_classes}

configs/hydra/default.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
defaults:
2+
- override job_logging: colorlog
3+
- override hydra_logging: colorlog
4+
5+
run:
6+
dir: "outputs/${mode}/${experiment_name}-${now:%Y-%m-%d_%H-%M-%S}"
7+
sweep:
8+
dir: "outputs/sweeps/${mode}/${experiment_name}-${now:%Y-%m-%d_%H-%M-%S}"
9+
subdir: ${hydra.job.override_dirname}
10+
job:
11+
config:
12+
override_dirname:
13+
exclude_keys:
14+
- experiment
15+
- params_search
16+
job_logging:
17+
handlers:
18+
file:
19+
filename: ${hydra.runtime.output_dir}/${mode}.log

configs/model/example_model.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
_target_: model.example_model.lightning_module.ExampleLightningModel
2+
num_classes: 10
3+
optimizer:
4+
_target_: torch.optim.Adam
5+
_partial_: true
6+
lr: 0.001
7+
loss_fn:
8+
_target_: torch.nn.functional.cross_entropy
9+
_partial_: true
10+
scheduler:
11+
_target_: torch.optim.lr_scheduler.StepLR
12+
step_size: 5
13+
gamma: 0.1
14+
_partial_: true
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# @package _global_
2+
3+
defaults:
4+
- override /hydra/sweeper: optuna
5+
6+
# Make sure this is aligned with the metric logged in lightning module
7+
optimized_metric: "val_multiclassaccuracy"
8+
9+
# See: https://hydra.cc/docs/plugins/optuna_sweeper/
10+
hydra:
11+
mode: "MULTIRUN"
12+
sweeper:
13+
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
14+
storage: null
15+
study_name: null
16+
n_jobs: 1
17+
direction: maximize
18+
n_trials: 3
19+
20+
#You can choose TPEsampler, RandomSampler, GridSampler, etc.
21+
# See: https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
22+
sampler:
23+
_target_: optuna.samplers.TPESampler
24+
seed: 42
25+
n_startup_trials: 3 # number of random sampling runs before optimization starts
26+
27+
# Hyperparameter search space
28+
params:
29+
model.optimizer.lr: interval(0.0001, 0.1)
30+
data.batch_size: choice(32, 64)

configs/predict.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
defaults:
2+
- _self_
3+
- model: ???
4+
- data: ???
5+
- trainer: default
6+
- experiment: ???
7+
- hydra: default
8+
9+
mode: predict
10+
ckpt_path: ??? # User must override this
11+
data_split: "test" # one of: train, test, val, or predict
12+
save_format: csv # json or csv

0 commit comments

Comments
 (0)