-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathn2v_train.py
More file actions
106 lines (85 loc) · 3.87 KB
/
n2v_train.py
File metadata and controls
106 lines (85 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from envutils import ENV, load_env, get_tiff_paths, get_argparser, log
from pathlib import Path
import os
import glob
import tifffile
import matplotlib.pyplot as plt
import numpy as np
from careamics import CAREamist
from careamics.config import create_n2v_configuration
from careamics.lightning import TrainDataModule
def train_n2v(train_dataset_name,
              validation_dataset_name,
              dataset_folder,
              models_folder,
              experiment_name,
              use_n2v2,
              use_augmentations,
              patch_size_z=None,
              patch_size=64,
              batch_size=16,
              num_epochs=10,
              axes="ZYX"):
    """Train an N2V model with CAREamics on TIFF datasets stored on disk.

    Args:
        train_dataset_name: Subfolder of ``dataset_folder`` used as training data.
        validation_dataset_name: Subfolder of ``dataset_folder`` used as
            validation data.
        dataset_folder: Root folder containing the dataset subfolders.
        models_folder: Root folder under which the experiment folder is created.
        experiment_name: Name of the experiment; also the name of the subfolder
            of ``models_folder`` used as the CAREamist working directory.
        use_n2v2: Forwarded to the configuration as ``use_n2v2``.
        use_augmentations: When falsy, an empty augmentation list is passed to
            the configuration; otherwise ``None`` is passed.
        patch_size_z: Patch depth. When ``None`` the patch has two dimensions
            ``(patch_size, patch_size)``, otherwise three
            ``(patch_size_z, patch_size, patch_size)``.
        patch_size: Patch size along the two remaining dimensions.
        batch_size: Forwarded to the configuration as ``batch_size``.
        num_epochs: Forwarded to the configuration as ``num_epochs``.
        axes: Axes string forwarded to the configuration.
    """
    # Resolve the input folders and make sure the output folder exists.
    train_dataset_folder = os.path.join(dataset_folder, train_dataset_name)
    validation_dataset_folder = os.path.join(dataset_folder, validation_dataset_name)
    model_folder = os.path.join(models_folder, experiment_name)
    os.makedirs(model_folder, exist_ok=True)

    # The patch gains a leading depth entry only when a depth was requested.
    if patch_size_z is None:
        patch_dims = (patch_size, patch_size)
    else:
        patch_dims = (patch_size_z, patch_size, patch_size)

    # An empty list is passed when augmentations are switched off; None otherwise
    # (presumably leaving the choice to CAREamics defaults - confirm).
    augmentation_setting = None if use_augmentations else []

    config = create_n2v_configuration(
        experiment_name=experiment_name,
        data_type="tiff",
        axes=axes,
        patch_size=patch_dims,
        batch_size=batch_size,
        num_epochs=num_epochs,
        use_n2v2=use_n2v2,
        augmentations=augmentation_setting,
    )

    # Both folders are handed to the data module; data is not held in memory.
    data_module = TrainDataModule(
        data_config=config.data_config,
        train_data=train_dataset_folder,
        val_data=validation_dataset_folder,
        use_in_memory=False,
    )

    # Instantiate the CAREamist with the experiment folder as working directory.
    careamist = CAREamist(source=config, work_dir=model_folder)

    # Train.
    # NOTE(review): validation data is already supplied through the data module;
    # confirm whether these split settings are honoured in that case.
    careamist.train(
        datamodule=data_module,
        val_percentage=0.0,
        val_minimum_split=100,
    )
if __name__ == "__main__":
    # Get a parser that includes some default ENV VARS overrides
    parser = get_argparser(description="Train a N2V model on the given dataset.")

    # Add script-specific variables. The dataset and experiment names are
    # required: each one is joined into a filesystem path, so a missing value
    # would otherwise only surface later as a TypeError in os.path.join.
    parser.add_argument('--train_dataset_name', type=str, required=True,
                        help='Training dataset name, as subfolder of the dataset directory containing the .tif files')
    parser.add_argument('--validation_dataset_name', type=str, required=True,
                        help='Validation dataset name, as subfolder of the dataset directory containing the .tif files')
    parser.add_argument('--experiment_name', type=str, required=True,
                        help='Name of the experiment. Will be used to create corresponding subfolders.')
    parser.add_argument('--use_n2v2', action="store_true", help='Whether to use N2V2.')
    parser.add_argument('--use_augmentations', action="store_true", help='Whether to use augmentations.')
    parser.add_argument('--patch_size_z', type=int, default=None, help="Patch depth dimension")
    parser.add_argument('--patch_size', type=int, default=64, help="Patch spatial dimension")
    parser.add_argument('--batch_size', type=int, default=16, help="Batch Size")
    parser.add_argument('--num_epochs', type=int, default=10, help="Epochs to train")
    # NOTE(review): this default ("ZXY") differs from the train_n2v() default
    # ("ZYX"); kept unchanged here - confirm which order is intended.
    parser.add_argument('--axes', type=str, default="ZXY", help="Axes used to interpret the TIFF files.")

    args = parser.parse_args()

    # Set Log Level from arguments
    log.setLevel(args.level)

    # Load env vars and args overrides into ENV dictionary
    load_env(args.env, parser_args=args)

    # Fail with a usage error instead of a TypeError deep inside path joining
    # when one of the root folders is not configured.
    dataset_folder = ENV.get("DATASET_FOLDER")
    models_folder = ENV.get("MODELS_FOLDER")
    if dataset_folder is None or models_folder is None:
        parser.error("DATASET_FOLDER and MODELS_FOLDER must both be set.")

    train_n2v(
        train_dataset_name=args.train_dataset_name,
        validation_dataset_name=args.validation_dataset_name,
        dataset_folder=dataset_folder,
        models_folder=models_folder,
        experiment_name=args.experiment_name,
        use_n2v2=args.use_n2v2,
        use_augmentations=args.use_augmentations,
        patch_size_z=args.patch_size_z,
        patch_size=args.patch_size,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        axes=args.axes,
    )