Skip to content

Commit 23d9399

Browse files
Merge pull request #80 from ArnovanHilten/dev
Add multiprocessing option
2 parents 76b69d2 + 6d50131 commit 23d9399

File tree

2 files changed

+29
-11
lines changed

GenNet.py

+6
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,12 @@ def make_parser_train(self, parser_train):
167167
metavar="number of epochs",
168168
default=1000,
169169
help='Hyperparameter: batch size')
170+
parser_train.add_argument(
171+
"-workers",
172+
type=int,
173+
metavar="number of workers for multiprocessing",
174+
default=1,
175+
help='Speed-up: number of workers (CPU cores) for multiprocessing. Can cause memory-leaks in some tensorflow versions')
170176
parser_train.add_argument(
171177
"-L1",
172178
metavar="",

GenNet_utils/Train_network.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ def train_classification(args):
4747

4848
if args.mixed_precision == True:
4949
use_mixed_precision()
50-
50+
51+
if args.workers > 1:
52+
multiprocessing = True
53+
else:
54+
multiprocessing = False
55+
5156
check_data(datapath=datapath, genotype_path=genotype_path, mode=problem_type)
5257

5358
global weight_positive_class, weight_negative_class
@@ -131,8 +136,8 @@ def train_classification(args):
131136
epochs=epochs,
132137
verbose=1,
133138
callbacks=[early_stop, save_best_model, csv_logger],
134-
workers=1,
135-
use_multiprocessing=False,
139+
workers=args.workers,
140+
use_multiprocessing=multiprocessing,
136141
validation_data=EvalGenerator(datapath=datapath, genotype_path=genotype_path, batch_size=batch_size,
137142
setsize=val_size_train,
138143
inputsize=inputsize, evalset="validation")
@@ -152,8 +157,8 @@ def train_classification(args):
152157
epochs=epochs,
153158
verbose=1,
154159
callbacks=[early_stop, save_best_model, csv_logger],
155-
workers=1,
156-
use_multiprocessing=False,
160+
workers=args.workers,
161+
use_multiprocessing=multiprocessing,
157162
validation_data=EvalGenerator(datapath=datapath, genotype_path=genotype_path, batch_size=batch_size,
158163
setsize=val_size_train,
159164
inputsize=inputsize, evalset="validation")
@@ -221,13 +226,20 @@ def train_regression(args):
221226
l1_value = args.L1
222227
problem_type = args.problem_type
223228
patience = args.patience
224-
225229

226230
if args.genotype_path == "undefined":
227231
genotype_path = datapath
228232
else:
229233
genotype_path = args.genotype_path
230-
234+
235+
if args.mixed_precision == True:
236+
use_mixed_precision()
237+
238+
if args.workers > 1:
239+
multiprocessing = True
240+
else:
241+
multiprocessing = False
242+
231243
check_data(datapath=datapath, genotype_path=genotype_path, mode=problem_type)
232244

233245
optimizer_model = tf.keras.optimizers.Adam(lr=lr_opt)
@@ -306,8 +318,8 @@ def train_regression(args):
306318
epochs=epochs,
307319
verbose=1,
308320
callbacks=[early_stop, save_best_model, csv_logger],
309-
workers=1,
310-
use_multiprocessing=False,
321+
workers=args.workers,
322+
use_multiprocessing=multiprocessing,
311323
validation_data=EvalGenerator(datapath=datapath, genotype_path=genotype_path, batch_size=batch_size,
312324
setsize=val_size_train, inputsize=inputsize, evalset="validation")
313325
)
@@ -325,8 +337,8 @@ def train_regression(args):
325337
epochs=epochs,
326338
verbose=1,
327339
callbacks=[early_stop, save_best_model, csv_logger],
328-
workers=1,
329-
use_multiprocessing=False,
340+
workers=args.workers,
341+
use_multiprocessing=multiprocessing,
330342
validation_data=EvalGenerator(datapath=datapath, genotype_path=genotype_path, batch_size=batch_size,
331343
setsize=val_size_train, inputsize=inputsize, evalset="validation")
332344
)

0 commit comments

Comments
 (0)