Skip to content

Commit 2f83327

Browse files
authored
Merge pull request #1 from flatironinstitute/main (#19)
* Add files via upload * boundary augmentation * commented reneu * commentted reneu * added * stuff * made change * tensorboard * here * changes * changes * change * yes * changes
1 parent 2ca781c commit 2f83327

File tree

8 files changed

+214
-38
lines changed

8 files changed

+214
-38
lines changed

examples/config.yaml

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
system:
2+
cpus: 1
3+
gpus: 1
4+
seed: 1
5+
6+
dataset:
7+
training:
8+
s3vol01700:
9+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_01700/img.h5",]
10+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_01700/label_v3.h5"
11+
s3vol02299:
12+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02299/img.h5",]
13+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02299/label_v3.h5"
14+
s3vol02400:
15+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02400/img_zyx_2400-2656_5700-5956_2770-3026.h5",]
16+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02400/label_v1_diane.h5"
17+
# s3vol02684:
18+
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02284/img_zyx_2400-2656_5700-5956_2770-3026.h5",]
19+
# label: "~/dropbox/40_gt/13_wasp_sample3/vol_0684/label_v1_diane.h5"
20+
s3vol02794:
21+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02794/img_zyx_2794-3050_5811-6067_8757-9013.h5",]
22+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02794/seg_v1_cropped.h5"
23+
s3vol03290:
24+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03290/img_zyx_3290-3546_2375-2631_8450-8706.h5",]
25+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03290/label_v1.h5"
26+
s3vol03700:
27+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03700/img_zyx_3700-3956_5000-5256_4250-4506.h5",]
28+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03700/label_v3.h5"
29+
s3vol03998:
30+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03998/img.h5",]
31+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03998/label_v1.h5"
32+
s3vol04900:
33+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04900/img.h5",]
34+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_04900/label_v1.h5"
35+
s3vol05250:
36+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05250/img_zyx_5250-5506_4600-4856_5500-5756.h5",]
37+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_05250/label_v3_remove_contact.h5"
38+
s3vol05450:
39+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05450/img_zyx_5450-5706_5350-5606_7000-7256.h5",]
40+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_05450/label_v4_chiyip.h5"
41+
validation:
42+
s3vol04000:
43+
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04000/img_zyx_4000-4256_3400-3656_8150-8406.h5",]
44+
label: "~/dropbox/40_gt/13_wasp_sample3/vol_04000/label_v3.h5"
45+
model:
46+
in_channels: 3
47+
out_channels: 3
48+
49+
train:
50+
iter_start: 0
51+
iter_stop: 1000000
52+
class_rebalance: false
53+
# batch size per GPU
54+
# The dataprovider should provide nGPU*batch_size batches!
55+
batch_size: 1
56+
output_dir: "./"
57+
patch_size: [128, 128, 128]
58+
learning_rate: 0.001
59+
#training_interval: 200
60+
#validation_interval: 2000
61+
training_interval: 2
62+
validation_interval: 4

neutorch/data/dataset.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1+
import math
12
import random
23
from functools import cached_property
3-
import math
44

55
import numpy as np
66
import torch
7-
from yacs.config import CfgNode
8-
97
from chunkflow.lib.cartesian_coordinate import Cartesian
8+
from yacs.config import CfgNode
109

11-
from neutorch.data.transform import *
1210
from neutorch.data.sample import SemanticSample
11+
from neutorch.data.transform import *
1312

1413
DEFAULT_PATCH_SIZE = Cartesian(128, 128, 128)
1514

@@ -246,7 +245,6 @@ def __next__(self):
246245

247246
class AffinityMapDataset(SemanticDataset):
248247
def __init__(self, samples: list):
249-
#patch_size: Cartesian = DEFAULT_PATCH_SIZE):
250248
super().__init__(samples)
251249

252250
@cached_property
@@ -267,12 +265,33 @@ def transform(self):
267265
Flip(),
268266
Transpose(),
269267
MissAlignment(),
270-
# Tranform to affinity map
271-
# there is a shrinking, so we put this transformation here
272-
# rather than the label2target function.
273268
Label2AffinityMap(probability=1.),
274269
])
275270

271+
class BoundaryAugmentationDataset(SemanticDataset):
272+
def __init__(self, samples: list):
273+
super().__init__(samples)
274+
275+
@cached_property
276+
def transform(self):
277+
return Compose([
278+
NormalizeTo01(probability=1.),
279+
AdjustBrightness(),
280+
AdjustContrast(),
281+
Gamma(),
282+
OneOf([
283+
Noise(),
284+
GaussianBlur2D(),
285+
]),
286+
MaskBox(),
287+
Perspective2D(),
288+
# RotateScale(probability=1.),
289+
# DropSection(),
290+
Flip(),
291+
Transpose(),
292+
MissAlignment(),
293+
])
294+
276295
if __name__ == '__main__':
277296

278297
from yacs.config import CfgNode
@@ -282,10 +301,9 @@ def transform(self):
282301
cfg = CfgNode.load_cfg(file)
283302
cfg.freeze()
284303

285-
sd = AffinityMapDataset(
304+
sd = BoundaryAugmentationDataset(
286305
path_list=['/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_01700/rna_v1.h5'],
287306
sample_name_to_image_versions=cfg.dataset.sample_name_to_image_versions,
288307
patch_size=Cartesian(128, 128, 128),
289308
)
290-
291-
# print(sd.samples)
309+

neutorch/data/patch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from functools import cached_property
2-
import numpy as np
32

3+
import numpy as np
44
# from torch import tensor, device
55
import torch
6+
67
# torch.multiprocessing.set_start_method('spawn')
78

89
# from chunkflow.lib.cartesian_coordinate import Cartesian

neutorch/data/transform.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
1-
from abc import ABC, abstractmethod
21
import random
2+
from abc import ABC, abstractmethod
33
from functools import cached_property
44

5-
6-
from chunkflow.lib.cartesian_coordinate import Cartesian
7-
# from copy import deepcopy
8-
5+
import cv2
96
import numpy as np
10-
7+
from chunkflow.lib.cartesian_coordinate import Cartesian
8+
# from reneu.lib.segmentation import seg_to_affs
119
from scipy.ndimage.filters import gaussian_filter
12-
# from scipy.ndimage import affine_transform
13-
14-
import cv2
15-
16-
from skimage.util import random_noise
1710
from skimage.transform import swirl
18-
19-
from reneu.lib.segmentation import seg_to_affs
11+
from skimage.util import random_noise
2012

2113
from .patch import Patch
2214

15+
# from copy import deepcopy
16+
17+
18+
# from scipy.ndimage import affine_transform
2319

2420
DEFAULT_PROBABILITY = .5
2521
DEFAULT_SHRINK_SIZE = (0, 0, 0, 0, 0, 0)

neutorch/train/affinity_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import click
44
from yacs.config import CfgNode
55

6-
from .base import TrainerBase
76
from neutorch.data.dataset import AffinityMapDataset
87

8+
from .base import TrainerBase
9+
910

1011
class AffinityMapTrainer(TrainerBase):
1112
def __init__(self, cfg: CfgNode) -> None:

neutorch/train/base.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
1+
import os
2+
import random
13
from abc import ABC, abstractproperty
24
from functools import cached_property
35
from glob import glob
4-
5-
import random
6-
import os
76
from time import time
87

9-
from yacs.config import CfgNode
108
import numpy as np
11-
12-
from chunkflow.lib.cartesian_coordinate import Cartesian
13-
149
import torch
15-
from torch.utils.tensorboard import SummaryWriter
10+
from chunkflow.lib.cartesian_coordinate import Cartesian
1611
from torch.utils.data import DataLoader
17-
from neutorch.data.patch import collate_batch
12+
from torch.utils.tensorboard import SummaryWriter
13+
from yacs.config import CfgNode
1814

19-
from neutorch.model.IsoRSUNet import Model
20-
from neutorch.model.io import save_chkpt, load_chkpt, log_tensor
21-
from neutorch.loss import BinomialCrossEntropyWithLogits
2215
from neutorch.data.dataset import worker_init_fn
16+
from neutorch.data.patch import collate_batch
17+
from neutorch.loss import BinomialCrossEntropyWithLogits
18+
from neutorch.model.io import load_chkpt, log_tensor, save_chkpt
19+
from neutorch.model.IsoRSUNet import Model
2320

2421

2522
class TrainerBase(ABC):

neutorch/train/boundary_aug.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
from functools import cached_property
3+
from time import time
4+
5+
import click
6+
import numpy as np
7+
import torch
8+
from torch.utils.tensorboard import SummaryWriter
9+
from yacs.config import CfgNode
10+
11+
from neutorch.data.dataset import BoundaryAugmentationDataset
12+
from neutorch.model.io import log_tensor, save_chkpt
13+
14+
from .base import SemanticTrainer, TrainerBase
15+
16+
class BoundaryAugTrainer(SemanticTrainer):
17+
def __init__(self, cfg: CfgNode) -> None:
18+
assert isinstance(cfg, CfgNode)
19+
super().__init__(cfg)
20+
self.cfg = cfg
21+
breakpoint()
22+
23+
@cached_property
24+
def training_dataset(self):
25+
return BoundaryAugmentationDataset.from_config(self.cfg, is_train=True)
26+
27+
@cached_property
28+
def validation_dataset(self):
29+
return BoundaryAugmentationDataset.from_config(self.cfg, is_train=False)
30+
31+
"""
32+
def call(self):
33+
writer = SummaryWriter(log_dir=self.cfg.train.output_dir)
34+
accumulated_loss = 0. #floating point
35+
36+
for image, label in self.training_data_loader:
37+
iter_idx += 1
38+
if iter_idx > self.cfg.train.iter_stop:
39+
print('exceeds maximum iteration:', self.cfg.train.iter_stop)
40+
return
41+
42+
ping = time()
43+
predict = self.model(image)
44+
loss = self.loss_module(predict, label)
45+
assert not torch.isnan(loss), 'loss is NaN.'
46+
47+
self.optimizer #
48+
loss.backward()
49+
self.optimizer.step()
50+
accumulated_loss += loss.tolist()
51+
52+
if iter_idx % self.cfg.train.training_interval == 0 and iter_idx > 0:
53+
per_voxel_loss = accumulated_loss / \
54+
self.cfg.train.training_interval / \
55+
self.voxel_num
56+
57+
print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds with loss: {per_voxel_loss}')
58+
accumulated_loss = 0.
59+
predict = self.post_processing(predict)
60+
writer.add_scalar('loss/train', per_voxel_loss, iter_idx)
61+
log_tensor(writer, 'train/image', image, 'image', iter_idx)
62+
log_tensor(writer, 'train/prediction', predict.detach(), 'image', iter_idx)
63+
log_tensor(writer, 'train/label', label, 'segmentation', iter_idx)
64+
65+
if iter_idx % self.cfg.train.validation_interval == 0 and iter_idx > 0:
66+
fname = os.path.join(self.cfg.train.output_dir, f'model_{iter_idx}.chkpt')
67+
print(f'save model to {fname}')
68+
save_chkpt(self.model, self.cfg.train.output_dir, iter_idx, self.optimizer)
69+
70+
print('evaluate prediction: ')
71+
validation_image, validation_label = next(self.validation_data_iter)
72+
73+
with torch.no_grad():
74+
validation_predict = self.model(validation_image)
75+
validation_loss = self.loss_module(validation_predict, validation_label)
76+
validation_predict = self.post_processing(validation_predict)
77+
per_voxel_loss = validation_loss.tolist() / self.voxel_num
78+
print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds with loss: {per_voxel_loss}')
79+
writer.add_scalar('loss/validation', per_voxel_loss, iter_idx)
80+
log_tensor(writer, 'validation/image', validation_image, 'image', iter_idx)
81+
log_tensor(writer, 'validation/prediction', validation_predict, 'image', iter_idx)
82+
log_tensor(writer, 'validation/label', validation_label, 'segmentation', iter_idx)
83+
84+
writer.close()
85+
"""
86+
87+
@click.command()
88+
@click.option('--config-file', '-c',
89+
type=click.Path(exists=True, dir_okay=False, file_okay=True, readable=True, resolve_path=True),
90+
default='./config.yaml',
91+
help = 'configuration file containing all the parameters.'
92+
)
93+
94+
def main(config_file: str):
95+
from neutorch.data.dataset import load_cfg
96+
cfg = load_cfg(config_file)
97+
trainer = BoundaryAugTrainer(cfg)
98+
trainer()
99+
100+

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
neutrain-denoise=neutorch.train.denoise:main
1919
neutrain-post=neutorch.train.post_synapses:main
2020
neutrain-affs=neutorch.train.affinity_map:main
21+
neutrain-ba=neutorch.train.boundary_aug:main
2122
''',
2223
classifiers=[
2324
'Development Status :: 4 - Beta',

0 commit comments

Comments
 (0)