|
2 | 2 |
|
3 | 3 | import keras.backend as K |
4 | 4 |
|
| 5 | +from csbdeep.models import BaseConfig |
5 | 6 | from csbdeep.utils import _raise, axes_check_and_normalize, axes_dict, backend_channels_last |
6 | 7 |
|
7 | 8 | from six import string_types |
8 | 9 |
|
9 | 10 | import numpy as np |
| 11 | +from logging.config import BaseConfigurator |
10 | 12 |
|
11 | 13 | # This class is a adapted version of csbdeep.models.config.py. |
12 | 14 | class N2VConfig(argparse.Namespace): |
@@ -68,83 +70,87 @@ class N2VConfig(argparse.Namespace): |
68 | 70 | """ |
69 | 71 |
|
70 | 72 | def __init__(self, X, **kwargs): |
71 | | - """See class docstring.""" |
72 | | - |
73 | | - assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported." |
74 | | - |
75 | | - n_dim = len(X.shape) - 2 |
76 | | - n_channel_in = X.shape[-1] |
77 | | - n_channel_out = n_channel_in |
78 | | - mean = np.mean(X) |
79 | | - std = np.std(X) |
80 | | - |
81 | | - if n_dim == 2: |
82 | | - axes = 'SYXC' |
83 | | - elif n_dim == 3: |
84 | | - axes = 'SZYXC' |
85 | | - |
86 | | - # parse and check axes |
87 | | - axes = axes_check_and_normalize(axes) |
88 | | - ax = axes_dict(axes) |
89 | | - ax = {a: (ax[a] is not None) for a in ax} |
90 | | - |
91 | | - (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.')) |
92 | | - not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.')) |
93 | | - |
94 | | - axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.')) |
95 | | - axes = axes.replace('S','') # remove sample axis if it exists |
96 | | - |
97 | | - if backend_channels_last(): |
98 | | - if ax['C']: |
99 | | - axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend())) |
| 73 | + |
| 74 | + # X is empty if config is None |
| 75 | + if (X.size != 0): |
| 76 | + |
| 77 | + assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported." |
| 78 | + |
| 79 | + n_dim = len(X.shape) - 2 |
| 80 | + n_channel_in = X.shape[-1] |
| 81 | + n_channel_out = n_channel_in |
| 82 | + mean = np.mean(X) |
| 83 | + std = np.std(X) |
| 84 | + |
| 85 | + if n_dim == 2: |
| 86 | + axes = 'SYXC' |
| 87 | + elif n_dim == 3: |
| 88 | + axes = 'SZYXC' |
| 89 | + |
| 90 | + # parse and check axes |
| 91 | + axes = axes_check_and_normalize(axes) |
| 92 | + ax = axes_dict(axes) |
| 93 | + ax = {a: (ax[a] is not None) for a in ax} |
| 94 | + |
| 95 | + (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.')) |
| 96 | + not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.')) |
| 97 | + |
| 98 | + axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.')) |
| 99 | + axes = axes.replace('S','') # remove sample axis if it exists |
| 100 | + |
| 101 | + if backend_channels_last(): |
| 102 | + if ax['C']: |
| 103 | + axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend())) |
| 104 | + else: |
| 105 | + axes += 'C' |
100 | 106 | else: |
101 | | - axes += 'C' |
102 | | - else: |
103 | | - if ax['C']: |
104 | | - axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend())) |
| 107 | + if ax['C']: |
| 108 | + axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend())) |
| 109 | + else: |
| 110 | + axes = 'C'+axes |
| 111 | + |
| 112 | + # normalization parameters |
| 113 | + self.mean = str(mean) |
| 114 | + self.std = str(std) |
| 115 | + # directly set by parameters |
| 116 | + self.n_dim = n_dim |
| 117 | + self.axes = axes |
| 118 | + self.n_channel_in = int(n_channel_in) |
| 119 | + self.n_channel_out = int(n_channel_out) |
| 120 | + |
| 121 | + # default config (can be overwritten by kwargs below) |
| 122 | + self.unet_residual = False |
| 123 | + self.unet_n_depth = 2 |
| 124 | + self.unet_kern_size = 5 if self.n_dim==2 else 3 |
| 125 | + self.unet_n_first = 32 |
| 126 | + self.unet_last_activation = 'linear' |
| 127 | + if backend_channels_last(): |
| 128 | + self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,) |
105 | 129 | else: |
106 | | - axes = 'C'+axes |
107 | | - |
108 | | - # normalization parameters |
109 | | - self.mean = str(mean) |
110 | | - self.std = str(std) |
111 | | - # directly set by parameters |
112 | | - self.n_dim = n_dim |
113 | | - self.axes = axes |
114 | | - self.n_channel_in = int(n_channel_in) |
115 | | - self.n_channel_out = int(n_channel_out) |
116 | | - |
117 | | - # default config (can be overwritten by kwargs below) |
118 | | - self.unet_residual = False |
119 | | - self.unet_n_depth = 2 |
120 | | - self.unet_kern_size = 5 if self.n_dim==2 else 3 |
121 | | - self.unet_n_first = 32 |
122 | | - self.unet_last_activation = 'linear' |
123 | | - if backend_channels_last(): |
124 | | - self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,) |
125 | | - else: |
126 | | - self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,) |
127 | | - |
128 | | - self.train_loss = 'mae' |
129 | | - self.train_epochs = 100 |
130 | | - self.train_steps_per_epoch = 400 |
131 | | - self.train_learning_rate = 0.0004 |
132 | | - self.train_batch_size = 16 |
133 | | - self.train_tensorboard = True |
134 | | - self.train_checkpoint = 'weights_best.h5' |
135 | | - self.train_reduce_lr = {'factor': 0.5, 'patience': 10} |
136 | | - self.batch_norm = True |
137 | | - self.n2v_perc_pix = 1.5 |
138 | | - self.n2v_patch_shape = (64, 64) if self.n_dim==2 else (64, 64, 64) |
139 | | - self.n2v_manipulator = 'uniform_withCP' |
140 | | - self.n2v_neighborhood_radius = 5 |
141 | | - |
142 | | - # disallow setting 'n_dim' manually |
143 | | - try: |
144 | | - del kwargs['n_dim'] |
145 | | - # warnings.warn("ignoring parameter 'n_dim'") |
146 | | - except: |
147 | | - pass |
| 130 | + self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,) |
| 131 | + |
| 132 | + self.train_loss = 'mae' |
| 133 | + self.train_epochs = 100 |
| 134 | + self.train_steps_per_epoch = 400 |
| 135 | + self.train_learning_rate = 0.0004 |
| 136 | + self.train_batch_size = 16 |
| 137 | + self.train_tensorboard = True |
| 138 | + self.train_checkpoint = 'weights_best.h5' |
| 139 | + self.train_reduce_lr = {'factor': 0.5, 'patience': 10} |
| 140 | + self.batch_norm = True |
| 141 | + self.n2v_perc_pix = 1.5 |
| 142 | + self.n2v_patch_shape = (64, 64) if self.n_dim==2 else (64, 64, 64) |
| 143 | + self.n2v_manipulator = 'uniform_withCP' |
| 144 | + self.n2v_neighborhood_radius = 5 |
| 145 | + |
| 146 | + # disallow setting 'n_dim' manually |
| 147 | + try: |
| 148 | + del kwargs['n_dim'] |
| 149 | + # warnings.warn("ignoring parameter 'n_dim'") |
| 150 | + except: |
| 151 | + pass |
| 152 | + |
| 153 | + self.probabilistic = False |
148 | 154 |
|
149 | 155 | for k in kwargs: |
150 | 156 | setattr(self, k, kwargs[k]) |
@@ -215,3 +221,16 @@ def _is_int(v,low=None,high=None): |
215 | 221 | return all(ok.values()), tuple(k for (k,v) in ok.items() if not v) |
216 | 222 | else: |
217 | 223 | return all(ok.values()) |
| 224 | + |
| 225 | + def update_parameters(self, allow_new=True, **kwargs): |
| 226 | + if not allow_new: |
| 227 | + attr_new = [] |
| 228 | + for k in kwargs: |
| 229 | + try: |
| 230 | + getattr(self, k) |
| 231 | + except AttributeError: |
| 232 | + attr_new.append(k) |
| 233 | + if len(attr_new) > 0: |
| 234 | + raise AttributeError("Not allowed to add new parameters (%s)" % ', '.join(attr_new)) |
| 235 | + for k in kwargs: |
| 236 | + setattr(self, k, kwargs[k]) |
0 commit comments