Skip to content

Commit e23d572

Browse files
committed
simplify as we are required to have it for all
1 parent 86605fa commit e23d572

1 file changed

Lines changed: 5 additions & 12 deletions

File tree

olmoearth_pretrain/internal/experiment.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55
from collections.abc import Callable
66
from dataclasses import dataclass
7-
from typing import cast, Optional, Any
7+
from typing import cast
88

99
import numpy as np
1010
from olmo_core.config import Config, StrEnum
@@ -20,6 +20,7 @@
2020
from olmoearth_pretrain._compat import deprecated_class_alias as _deprecated_class_alias
2121
from olmoearth_pretrain.data.constants import Modality
2222
from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig
23+
from olmoearth_pretrain.launch.beaker import OlmoEarthBeakerLaunchConfig
2324
from olmoearth_pretrain.data.dataset import (
2425
OlmoEarthDatasetConfig,
2526
collate_olmoearth_pretrain,
@@ -34,22 +35,14 @@
3435

3536
logger = logging.getLogger(__name__)
3637

37-
# Goal is to make the main script runnable without beaker installed
38-
try:
39-
from olmoearth_pretrain.launch.beaker import OlmoEarthBeakerLaunchConfig
40-
LAUNCH_CONFIG_TYPE = Optional[OlmoEarthBeakerLaunchConfig]
41-
except ImportError:
42-
logger.warning("Beaker launch config not available, please install beaker to use it")
43-
LAUNCH_CONFIG_TYPE = Optional[Any]
44-
4538
@dataclass
4639
class CommonComponents(Config):
4740
"""Any configurable items that are common to all experiments."""
4841

4942
run_name: str
5043
save_folder: str
5144
training_modalities: list[str]
52-
launch: LAUNCH_CONFIG_TYPE = None
45+
launch: OlmoEarthBeakerLaunchConfig | None = None
5346
nccl_debug: bool = False
5447
# callbacks: dict[str, Callback]
5548

@@ -91,7 +84,7 @@ class OlmoEarthExperimentConfig(Config):
9184
data_loader: OlmoEarthDataLoaderConfig # will likely be fixed for us
9285
train_module: OlmoEarthTrainModuleConfig
9386
trainer: TrainerConfig
94-
launch: LAUNCH_CONFIG_TYPE = None
87+
launch: OlmoEarthBeakerLaunchConfig | None = None
9588
visualize: OlmoEarthVisualizeConfig | None = None
9689
init_seed: int = 12536
9790

@@ -107,7 +100,7 @@ class BenchmarkExperimentConfig(Config):
107100
"""Configuration for a throughput benchmarking run."""
108101

109102
benchmark: ThroughputBenchmarkRunnerConfig
110-
launch: LAUNCH_CONFIG_TYPE = None
103+
launch: OlmoEarthBeakerLaunchConfig | None = None
111104

112105

113106
def split_common_overrides(overrides: list[str]) -> tuple[list[str], list[str]]:

0 commit comments

Comments
 (0)