File tree (Expand file tree / Collapse file tree)
2 files changed: +14 -17 lines changed
sub-packages/bionemo-esm2/src/bionemo/esm2/scripts (Expand file tree / Collapse file tree)
2 files changed: +14 -17 lines changed
Original file line number | Diff line number | Diff line change
@@ -278,7 +278,10 @@ def train_model(
278278 )
279279 # Configure the model
280280 train_metric = None
281- if task_type == "regression" :
281+ is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
282+ if is_model_parallel :
283+ valid_metric = None # metric logging under model parallelism is not supported yet
284+ elif task_type == "regression" :
282285 valid_metric = TorchmetricsConfig (class_path = "MeanSquaredError" , task = "regression" , metric_name = "val_mse" )
283286 else :
284287 valid_metric = TorchmetricsConfig (
@@ -292,11 +295,6 @@ def train_model(
292295 metric_name = "val_acc" ,
293296 )
294297
295- if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
296- train_metric is not None or valid_metric is not None
297- ):
298- raise NotImplementedError ("Metric logging under model parallelism is not supported yet." )
299-
300298 config = config_class (
301299 task_type = task_type ,
302300 encoder_frozen = encoder_frozen ,
Original file line number | Diff line number | Diff line change
@@ -265,17 +265,16 @@ def main(
265265 )
266266 # Configure the model
267267 train_metric = None
268- valid_metric = TorchmetricsConfig (
269- class_path = "text.Perplexity" ,
270- task = "pretraining" ,
271- kwargs = {"ignore_index" : MLM_LOSS_IGNORE_INDEX },
272- metric_name = "val_ppl" ,
273- )
274- if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
275- train_metric is not None or valid_metric is not None
276- ):
277- raise NotImplementedError ("Metric logging under model parallelism is not supported yet." )
278-
268+ is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
269+ if is_model_parallel :
270+ valid_metric = None # metric logging under model parallelism is not supported yet
271+ else :
272+ valid_metric = TorchmetricsConfig (
273+ class_path = "text.Perplexity" ,
274+ task = "pretraining" ,
275+ kwargs = {"ignore_index" : MLM_LOSS_IGNORE_INDEX },
276+ metric_name = "val_ppl" ,
277+ )
279278 esm2_config = ESM2Config (
280279 seq_length = max_seq_length ,
281280 num_layers = num_layers ,
You can’t perform that action at this time.
0 commit comments