@@ -47,7 +47,12 @@ def train_classification(args):
47
47
48
48
if args .mixed_precision == True :
49
49
use_mixed_precision ()
50
-
50
+
51
+ if args .workers > 1 :
52
+ multiprocessing = True
53
+ else :
54
+ multiprocessing = False
55
+
51
56
check_data (datapath = datapath , genotype_path = genotype_path , mode = problem_type )
52
57
53
58
global weight_positive_class , weight_negative_class
@@ -131,8 +136,8 @@ def train_classification(args):
131
136
epochs = epochs ,
132
137
verbose = 1 ,
133
138
callbacks = [early_stop , save_best_model , csv_logger ],
134
- workers = 1 ,
135
- use_multiprocessing = False ,
139
+ workers = args . workers ,
140
+ use_multiprocessing = multiprocessing ,
136
141
validation_data = EvalGenerator (datapath = datapath , genotype_path = genotype_path , batch_size = batch_size ,
137
142
setsize = val_size_train ,
138
143
inputsize = inputsize , evalset = "validation" )
@@ -152,8 +157,8 @@ def train_classification(args):
152
157
epochs = epochs ,
153
158
verbose = 1 ,
154
159
callbacks = [early_stop , save_best_model , csv_logger ],
155
- workers = 1 ,
156
- use_multiprocessing = False ,
160
+ workers = args . workers ,
161
+ use_multiprocessing = multiprocessing ,
157
162
validation_data = EvalGenerator (datapath = datapath , genotype_path = genotype_path , batch_size = batch_size ,
158
163
setsize = val_size_train ,
159
164
inputsize = inputsize , evalset = "validation" )
@@ -221,13 +226,20 @@ def train_regression(args):
221
226
l1_value = args .L1
222
227
problem_type = args .problem_type
223
228
patience = args .patience
224
-
225
229
226
230
if args .genotype_path == "undefined" :
227
231
genotype_path = datapath
228
232
else :
229
233
genotype_path = args .genotype_path
230
-
234
+
235
+ if args .mixed_precision == True :
236
+ use_mixed_precision ()
237
+
238
+ if args .workers > 1 :
239
+ multiprocessing = True
240
+ else :
241
+ multiprocessing = False
242
+
231
243
check_data (datapath = datapath , genotype_path = genotype_path , mode = problem_type )
232
244
233
245
optimizer_model = tf .keras .optimizers .Adam (lr = lr_opt )
@@ -306,8 +318,8 @@ def train_regression(args):
306
318
epochs = epochs ,
307
319
verbose = 1 ,
308
320
callbacks = [early_stop , save_best_model , csv_logger ],
309
- workers = 1 ,
310
- use_multiprocessing = False ,
321
+ workers = args . workers ,
322
+ use_multiprocessing = multiprocessing ,
311
323
validation_data = EvalGenerator (datapath = datapath , genotype_path = genotype_path , batch_size = batch_size ,
312
324
setsize = val_size_train , inputsize = inputsize , evalset = "validation" )
313
325
)
@@ -325,8 +337,8 @@ def train_regression(args):
325
337
epochs = epochs ,
326
338
verbose = 1 ,
327
339
callbacks = [early_stop , save_best_model , csv_logger ],
328
- workers = 1 ,
329
- use_multiprocessing = False ,
340
+ workers = args . workers ,
341
+ use_multiprocessing = multiprocessing ,
330
342
validation_data = EvalGenerator (datapath = datapath , genotype_path = genotype_path , batch_size = batch_size ,
331
343
setsize = val_size_train , inputsize = inputsize , evalset = "validation" )
332
344
)
0 commit comments