Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion olmoearth_pretrain/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def _fill_missing_modality(
) -> OlmoEarthSample:
"""Fill an array of shape of modality with the missing value."""
expected_shape = sample.get_expected_shape(modality)
logger.info(f"Filling {modality} with shape {expected_shape}")
logger.debug(f"Filling {modality} with shape {expected_shape}")
return np.full(
expected_shape,
fill_value=MISSING_VALUE,
Expand Down
3 changes: 2 additions & 1 deletion olmoearth_pretrain/internal/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def launch(config: OlmoEarthExperimentConfig) -> None:
logger.info("Launching the experiment")
logger.info(config)
# Set follow=False if you don't want to stream the logs to the terminal
config.launch.launch(follow=False)
# Default to enabling torchrun so we can run multi gpu scripts on single gpu
config.launch.launch(follow=False, torchrun=True)


def prep(config: OlmoEarthExperimentConfig) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def model_forward(
if extra_metrics is not None:
self.log_extra_metrics(extra_metrics)
with torch.no_grad():
logger.info("Target Encoder forward pass...")
logger.debug("Target Encoder forward pass...")
output_dict = self.model.target_encoder.forward(
batch.unmask(),
patch_size=patch_size,
Expand Down
4 changes: 3 additions & 1 deletion olmoearth_pretrain/train/train_module/train_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ def state_dict(self) -> dict[str, Any]:
"""Get the state dict."""
return self._get_state_dict(self.state_dict_save_opts)

def state_dict_to_load(self, metadata: Metadata) -> dict[str, Any]:
def state_dict_to_load(
self, metadata: Metadata, optim: bool | None = None
) -> dict[str, Any]:
Comment thread
Hgherzog marked this conversation as resolved.
"""Get the state dict to load."""
load_opts = self.state_dict_load_opts
return self._get_state_dict(load_opts)
Expand Down
1 change: 1 addition & 0 deletions requirements-beaker.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
beaker-py==1.34.1
docker>=5.0,<8.0
google-cloud-compute
packaging
pydantic>=1.8.2,<3.0
PyYAML
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ai2-olmo-core @ git+https://github.com/allenai/OLMo-core.git@abc12e50ba756c21e575452cfc6f150dafa9509e # Pin here until >2.1.0 is released.
ai2-olmo-core==2.3.0
albumentations
cartopy
class-registry
Expand Down