Skip to content

Commit 4f426fe

Browse files
Make matched fixes to invalid metric logging.
1 parent a50936c commit 4f426fe

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,10 @@ def train_model(
278278
)
279279
# Configure the model
280280
train_metric = None
281-
if task_type == "regression":
281+
is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
282+
if is_model_parallel:
283+
valid_metric = None # metric logging under model parallelism is not supported yet
284+
elif task_type == "regression":
282285
valid_metric = TorchmetricsConfig(class_path="MeanSquaredError", task="regression", metric_name="val_mse")
283286
else:
284287
valid_metric = TorchmetricsConfig(
@@ -292,11 +295,6 @@ def train_model(
292295
metric_name="val_acc",
293296
)
294297

295-
if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
296-
train_metric is not None or valid_metric is not None
297-
):
298-
raise NotImplementedError("Metric logging under model parallelism is not supported yet.")
299-
300298
config = config_class(
301299
task_type=task_type,
302300
encoder_frozen=encoder_frozen,

sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,16 @@ def main(
265265
)
266266
# Configure the model
267267
train_metric = None
268-
valid_metric = TorchmetricsConfig(
269-
class_path="text.Perplexity",
270-
task="pretraining",
271-
kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
272-
metric_name="val_ppl",
273-
)
274-
if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
275-
train_metric is not None or valid_metric is not None
276-
):
277-
raise NotImplementedError("Metric logging under model parallelism is not supported yet.")
278-
268+
is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
269+
if is_model_parallel:
270+
valid_metric = None # metric logging under model parallelism is not supported yet
271+
else:
272+
valid_metric = TorchmetricsConfig(
273+
class_path="text.Perplexity",
274+
task="pretraining",
275+
kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
276+
metric_name="val_ppl",
277+
)
279278
esm2_config = ESM2Config(
280279
seq_length=max_seq_length,
281280
num_layers=num_layers,

0 commit comments

Comments
 (0)