-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathbcss.yaml
More file actions
131 lines (131 loc) · 4.29 KB
/
bcss.yaml
File metadata and controls
131 lines (131 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
---
trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss}
max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513}
log_every_n_steps: 6
plugins:
class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
init_args:
submodule: decoder
callbacks:
- class_path: eva.callbacks.ConfigurationLogger
- class_path: lightning.pytorch.callbacks.TQDMProgressBar
init_args:
refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
- class_path: eva.vision.callbacks.SemanticSegmentationLogger
init_args:
log_every_n_epochs: 1
mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
filename: best
save_last: true
save_top_k: 1
monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore}
mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
min_delta: 0
patience: 100
monitor: *MONITOR_METRIC
mode: *MONITOR_METRIC_MODE
logger:
- class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: *OUTPUT_ROOT
name: ""
model:
class_path: eva.vision.models.modules.SemanticSegmentationModule
init_args:
encoder:
class_path: eva.vision.models.ModelFromRegistry
init_args:
model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
model_kwargs:
out_indices: ${oc.env:OUT_INDICES, 1}
model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
decoder:
class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
init_args:
in_features: ${oc.env:IN_FEATURES, 384}
num_classes: &NUM_CLASSES 6
criterion:
class_path: eva.vision.losses.DiceLoss
init_args:
softmax: true
batch: true
lr_multiplier_encoder: 0.0
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: ${oc.env:LR_VALUE, 0.002}
lr_scheduler:
class_path: torch.optim.lr_scheduler.PolynomialLR
init_args:
total_iters: *MAX_STEPS
power: 0.9
postprocess:
predictions_transforms:
- class_path: torch.argmax
init_args:
dim: 1
metrics:
common:
- class_path: eva.metrics.AverageLoss
evaluation:
- class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
init_args:
num_classes: *NUM_CLASSES
- class_path: torchmetrics.ClasswiseWrapper
init_args:
metric:
class_path: eva.vision.metrics.GeneralizedDiceScore
init_args:
num_classes: *NUM_CLASSES
weight_type: linear
per_class: true
data:
class_path: eva.DataModule
init_args:
datasets:
train:
class_path: eva.vision.datasets.BCSS
init_args: &DATASET_ARGS
root: ${oc.env:DATA_ROOT, ./data/bcss}
split: train
target_mpp: 0.5
sampler:
class_path: eva.vision.data.wsi.patching.samplers.GridSampler
init_args:
max_samples: 1000
transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
size: ${oc.env:RESIZE_DIM, 224}
mean: *NORMALIZE_MEAN
std: *NORMALIZE_STD
val:
class_path: eva.vision.datasets.BCSS
init_args:
<<: *DATASET_ARGS
split: val
test:
class_path: eva.vision.datasets.BCSS
init_args:
<<: *DATASET_ARGS
split: test
dataloaders:
train:
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
shuffle: true
val:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
test:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS