Skip to content

Commit 2791f4f

Browse files
author
valhassan
committed
config file added
1 parent 916bb74 commit 2791f4f

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

configs/segformer_config_RGB.yaml

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
seed_everything: true
2+
3+
trainer:
4+
accelerator: "gpu"
5+
devices: -1
6+
strategy:
7+
class_path: lightning.pytorch.strategies.DDPStrategy
8+
init_args:
9+
find_unused_parameters: false
10+
gradient_as_bucket_view: true
11+
static_graph: true
12+
gradient_clip_val: 1.0
13+
precision: "16-mixed"
14+
sync_batchnorm: true
15+
logger:
16+
class_path: lightning.pytorch.loggers.mlflow.MLFlowLogger
17+
init_args:
18+
save_dir: /home/valhassa/Projects/geo-deep-learning/logs
19+
log_model: all
20+
experiment_name: "gdl_experiment"
21+
run_name: "gdl_run"
22+
callbacks:
23+
- class_path: lightning.pytorch.callbacks.EarlyStopping
24+
init_args:
25+
monitor: "val_loss"
26+
mode: "min"
27+
verbose: False
28+
patience: 20
29+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
30+
init_args:
31+
monitor: "val_loss"
32+
mode: "min"
33+
save_top_k: 1
34+
filename: "model-{epoch:02d}-{val_loss:.3f}"
35+
- class_path: tools.callbacks.segmentation_visualization.VisualizationCallback
36+
init_args:
37+
max_samples: 3
38+
mean: ${data.init_args.mean}
39+
std: ${data.init_args.std}
40+
data_type_max: ${data.init_args.data_type_max}
41+
num_classes: ${model.init_args.num_classes}
42+
class_colors: ${model.init_args.class_colors}
43+
max_epochs: 10
44+
45+
model:
46+
class_path: tasks_with_models.segmentation_segformer.SegmentationSegformer
47+
init_args:
48+
encoder: "mit_b0" # "mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"
49+
in_channels: 3
50+
weights: imagenet
51+
max_samples: 6
52+
num_classes: 5
53+
mean: ${data.init_args.mean}
54+
std: ${data.init_args.std}
55+
data_type_max: ${data.init_args.data_type_max}
56+
loss:
57+
class_path: segmentation_models_pytorch.losses.DiceLoss
58+
init_args:
59+
mode: "multiclass"
60+
class_labels: ["background", "fore", "hydro", "roads", "buildings"]
61+
class_colors: ["#000000", "#008000", "#0000FF", "#FFFF00", "#FF0000"]
62+
weights_from_checkpoint_path: null
63+
64+
optimizer:
65+
class_path: AdamW
66+
init_args:
67+
lr: 6e-5
68+
69+
lr_scheduler:
70+
class_path: ReduceLROnPlateau
71+
init_args:
72+
monitor: "val_loss"
73+
mode: "min"
74+
factor: 0.1
75+
patience: 10
76+
cooldown: 1
77+
min_lr: 6e-8
78+
79+
data:
80+
class_path: datamodules.imagery_NonGeoDataModule.BlueSkyNonGeoDataModule
81+
init_args:
82+
batch_size: 4
83+
num_workers: 8
84+
data_type_max: 255
85+
patch_size:
86+
- 512
87+
- 512
88+
mean:
89+
- 0.3992
90+
- 0.4283
91+
- 0.3998
92+
std:
93+
- 0.1672
94+
- 0.1800
95+
- 0.1584
96+
csv_root_folder: /export/sata01/wspace/test_dir/multi/all_rgb_data/patches/4cls_RGB
97+
patches_root_folder: /export/sata01/wspace/test_dir/multi/all_rgb_data/patches/4cls_RGB
98+
99+
ckpt_path: null

0 commit comments

Comments
 (0)