@@ -20,6 +20,10 @@ def create_python_dataset_reader(args):
     readers = []
     for role in ['train', 'val', 'test']:
         role_dir = getattr(args, f'{role}_dir')
+        if not role_dir:
+            continue
+        if role == 'val':
+            role = 'validate'
         dataset = CosmoFlowDataset(role_dir, args.input_width, args.num_secrets)
         reader = lbann.util.data.construct_python_dataset_reader(dataset, role=role)
         readers.append(reader)
@@ -35,11 +39,13 @@ def create_cosmoflow_data_reader(
         num_responses (int): The number of parameters to predict.
     """
 
-    reader_args = [
-        {"role": "train", "data_filename": train_path},
-        {"role": "validate", "data_filename": val_path},
-        # {"role": "test", "data_filename": test_path},
-    ]
+    reader_args = []
+    if train_path:
+        reader_args.append({"role": "train", "data_filename": train_path})
+    if val_path:
+        reader_args.append({"role": "validate", "data_filename": val_path})
+    if test_path:
+        reader_args.append({"role": "test", "data_filename": test_path})
 
     for reader_arg in reader_args:
         reader_arg["data_file_pattern"] = "{}/*.hdf5".format(
@@ -142,7 +148,7 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
         default_dir = '{}/{}'.format(default_lc_dataset, role)
         parser.add_argument(
             '--{}-dir'.format(role), action='store', type=str,
-            default=default_dir,
+            default=default_dir if role == 'train' else None,
             help='the directory of the {} dataset'.format(role))
     parser.add_argument(
         '--synthetic', action='store_true',
@@ -156,6 +162,9 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
     parser.add_argument(
         '--transform-input', action='store_true',
         help='Apply log1p transformation to model inputs')
+    parser.add_argument(
+        '--dropout-keep-prob', action='store', type=float, default=0.5,
+        help='Probability of keeping activations in dropout layers (default: 0.5). Set to 1 to disable dropout')
 
     # Parallelism arguments
     parser.add_argument(
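
For reference, a minimal standalone sketch (not part of the patch) of how the new --dropout-keep-prob flag behaves when parsed; only the add_argument call introduced above is taken from the diff, everything else here is illustrative:

import argparse

# Illustrative only: mirrors the add_argument call added in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dropout-keep-prob', action='store', type=float, default=0.5,
    help='Probability of keeping activations in dropout layers '
         '(default: 0.5). Set to 1 to disable dropout')

args = parser.parse_args(['--dropout-keep-prob', '0.8'])
print(args.dropout_keep_prob)                   # 0.8
print(parser.parse_args([]).dropout_keep_prob)  # falls back to the 0.5 default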
@@ -227,7 +236,8 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
         learning_rate=args.optimizer_learning_rate,
         min_distconv_width=args.min_distconv_width,
         mlperf=args.mlperf,
-        transform_input=args.transform_input)
+        transform_input=args.transform_input,
+        dropout_keep_prob=args.dropout_keep_prob)
 
     # Add profiling callbacks if needed.
     model.callbacks.extend(lbann.contrib.args.create_profile_callbacks(args))
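
The companion change inside cosmoflow.construct_cosmoflow_model is not shown in this diff. As a hedged sketch of how the new dropout_keep_prob argument could be consumed on the model side, assuming LBANN's Python front end exposes lbann.Dropout with a keep_prob field (the helper name below is hypothetical, not taken from this diff):

import lbann

# Hypothetical helper: wraps a tensor in a dropout layer, or passes it through
# unchanged when dropout is effectively disabled.
def apply_dropout(x, dropout_keep_prob=0.5, name=None):
    if dropout_keep_prob >= 1.0:
        # A keep probability of 1 keeps every activation, so skip the layer.
        return x
    return lbann.Dropout(x, keep_prob=dropout_keep_prob, name=name)

Skipping the layer outright, rather than emitting a Dropout with keep_prob=1, matches the flag's documented "set to 1 to disable dropout" behavior and avoids a no-op layer in the compute graph.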
@@ -274,7 +284,7 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
         environment['DISTCONV_JIT_CACHEPATH'] = f'{application_path}/DaCe_kernels/.dacecache'
 
     if args.synthetic or args.no_datastore:
-        lbann_args = []
+        lbann_args = ['--num_io_threads=8']
     else:
         lbann_args = ['--use_data_store']
     lbann_args += lbann.contrib.args.get_profile_args(args)