@@ -78,112 +78,112 @@ def test_create_fft_sampled():
78
78
assert result .exit_code == 0
79
79
80
80
81
- def test_normalization ():
82
- from mnist_cnn .scripts .calculate_normalization import main
83
- import pandas as pd
84
- from dl_framework .data import get_bundles , open_fft_pair , do_normalisation
85
- import re
86
- import torch
87
-
88
- data_path = "./tests/build/mnist"
89
- out_path = "./tests/build/mnist/normalization_factors.csv"
90
-
91
- runner = CliRunner ()
92
- options = [data_path , out_path ]
93
- result = runner .invoke (main , options )
94
- print (traceback .print_exception (* result .exc_info ))
95
-
96
- assert result .exit_code == 0
97
-
98
- factors = pd .read_csv (out_path )
99
-
100
- assert (
101
- factors .keys ()
102
- == ["train_mean_c0" , "train_std_c0" , "train_mean_c1" , "train_std_c1" , ]
103
- ).all ()
104
- assert ~ np .isnan (factors .values ).all ()
105
- assert ~ np .isinf (factors .values ).all ()
106
- assert (factors .values != 0 ).all ()
107
-
108
- bundle_paths = get_bundles (data_path )
109
- bundle_paths = [
110
- path for path in bundle_paths if re .findall ("fft_bundle_samp_train" , path .name )
111
- ]
112
-
113
- bundles = [open_fft_pair (bund ) for bund in bundle_paths ]
114
-
115
- a = np .stack ((bundles [0 ][0 ][:, 0 ], bundles [0 ][0 ][:, 1 ]), axis = 1 )
116
-
117
- assert np .isclose (do_normalisation (torch .tensor (a ), factors ).mean (), 0 , atol = 1e-1 )
118
- assert np .isclose (do_normalisation (torch .tensor (a ), factors ).std (), 1 , atol = 1e-1 )
119
-
120
-
121
- def test_train_cnn ():
122
- from mnist_cnn .scripts .train_cnn import main
123
-
124
- data_path = "./tests/build/mnist"
125
- path_model = "./tests/build/mnist/test.model"
126
- arch = "UNet_denoise"
127
- norm_path = "./tests/build/mnist/normalization_factors.csv"
128
- epochs = "5"
129
- lr = "1e-3"
130
- lr_type = "mse"
131
- bs = "2"
132
-
133
- runner = CliRunner ()
134
- options = [
135
- data_path ,
136
- path_model ,
137
- arch ,
138
- norm_path ,
139
- epochs ,
140
- lr ,
141
- lr_type ,
142
- bs ,
143
- "-fourier" ,
144
- False ,
145
- "-pretrained" ,
146
- False ,
147
- "-inspection" ,
148
- False ,
149
- "-test" ,
150
- True ,
151
- ]
152
- result = runner .invoke (main , options )
153
- print (traceback .print_exception (* result .exc_info ))
154
-
155
- assert result .exit_code == 0
156
-
157
-
158
- def test_find_lr ():
159
- from mnist_cnn .scripts .find_lr import main
160
-
161
- data_path = "./tests/build/mnist"
162
- arch = "UNet_denoise"
163
- norm_path = "./tests/build/mnist/normalization_factors.csv"
164
- lr_type = "mse"
165
-
166
- runner = CliRunner ()
167
- options = [
168
- data_path ,
169
- arch ,
170
- data_path ,
171
- lr_type ,
172
- norm_path ,
173
- "-max_iter" ,
174
- "400" ,
175
- "-min_lr" ,
176
- "1e-6" ,
177
- "-max_lr" ,
178
- "1e-1" ,
179
- "-fourier" ,
180
- False ,
181
- "-pretrained" ,
182
- False ,
183
- "-save" ,
184
- True ,
185
- ]
186
- result = runner .invoke (main , options )
187
- print (traceback .print_exception (* result .exc_info ))
188
-
189
- assert result .exit_code == 0
81
+ # def test_normalization():
82
+ # from mnist_cnn.scripts.calculate_normalization import main
83
+ # import pandas as pd
84
+ # from dl_framework.data import get_bundles, open_fft_pair, do_normalisation
85
+ # import re
86
+ # import torch
87
+
88
+ # data_path = "./tests/build/mnist"
89
+ # out_path = "./tests/build/mnist/normalization_factors.csv"
90
+
91
+ # runner = CliRunner()
92
+ # options = [data_path, out_path]
93
+ # result = runner.invoke(main, options)
94
+ # print(traceback.print_exception(*result.exc_info))
95
+
96
+ # assert result.exit_code == 0
97
+
98
+ # factors = pd.read_csv(out_path)
99
+
100
+ # assert (
101
+ # factors.keys()
102
+ # == ["train_mean_c0", "train_std_c0", "train_mean_c1", "train_std_c1", ]
103
+ # ).all()
104
+ # assert ~np.isnan(factors.values).all()
105
+ # assert ~np.isinf(factors.values).all()
106
+ # assert (factors.values != 0).all()
107
+
108
+ # bundle_paths = get_bundles(data_path)
109
+ # bundle_paths = [
110
+ # path for path in bundle_paths if re.findall("fft_bundle_samp_train", path.name)
111
+ # ]
112
+
113
+ # bundles = [open_fft_pair(bund) for bund in bundle_paths]
114
+
115
+ # a = np.stack((bundles[0][0][:, 0], bundles[0][0][:, 1]), axis=1)
116
+
117
+ # assert np.isclose(do_normalisation(torch.tensor(a), factors).mean(), 0, atol=1e-1)
118
+ # assert np.isclose(do_normalisation(torch.tensor(a), factors).std(), 1, atol=1e-1)
119
+
120
+
121
+ # def test_train_cnn():
122
+ # from mnist_cnn.scripts.train_cnn import main
123
+
124
+ # data_path = "./tests/build/mnist"
125
+ # path_model = "./tests/build/mnist/test.model"
126
+ # arch = "UNet_denoise"
127
+ # norm_path = "./tests/build/mnist/normalization_factors.csv"
128
+ # epochs = "5"
129
+ # lr = "1e-3"
130
+ # lr_type = "mse"
131
+ # bs = "2"
132
+
133
+ # runner = CliRunner()
134
+ # options = [
135
+ # data_path,
136
+ # path_model,
137
+ # arch,
138
+ # norm_path,
139
+ # epochs,
140
+ # lr,
141
+ # lr_type,
142
+ # bs,
143
+ # "-fourier",
144
+ # False,
145
+ # "-pretrained",
146
+ # False,
147
+ # "-inspection",
148
+ # False,
149
+ # "-test",
150
+ # True,
151
+ # ]
152
+ # result = runner.invoke(main, options)
153
+ # print(traceback.print_exception(*result.exc_info))
154
+
155
+ # assert result.exit_code == 0
156
+
157
+
158
+ # def test_find_lr():
159
+ # from mnist_cnn.scripts.find_lr import main
160
+
161
+ # data_path = "./tests/build/mnist"
162
+ # arch = "UNet_denoise"
163
+ # norm_path = "./tests/build/mnist/normalization_factors.csv"
164
+ # lr_type = "mse"
165
+
166
+ # runner = CliRunner()
167
+ # options = [
168
+ # data_path,
169
+ # arch,
170
+ # data_path,
171
+ # lr_type,
172
+ # norm_path,
173
+ # "-max_iter",
174
+ # "400",
175
+ # "-min_lr",
176
+ # "1e-6",
177
+ # "-max_lr",
178
+ # "1e-1",
179
+ # "-fourier",
180
+ # False,
181
+ # "-pretrained",
182
+ # False,
183
+ # "-save",
184
+ # True,
185
+ # ]
186
+ # result = runner.invoke(main, options)
187
+ # print(traceback.print_exception(*result.exc_info))
188
+
189
+ # assert result.exit_code == 0
0 commit comments