Skip to content

Commit 41a0325

Browse files
author
piperwolters
committed
run pre-commit
1 parent e051448 commit 41a0325

12 files changed

Lines changed: 107 additions & 44 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,10 @@ repos:
5454
- -s
5555
- B101,B311,B614
5656

57-
- repo: local
57+
- repo: https://github.com/econchick/interrogate
58+
rev: 1.7.0
5859
hooks:
5960
- id: interrogate
60-
name: interrogate
61-
language: system
62-
entry: uv run interrogate
63-
types: [python]
6461
args:
6562
- --ignore-init-method
6663
- --ignore-init-module
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""OlmoEarth Pretrain v1 model package."""
22

3-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.olmoearth_pretrain_v1 import OlmoEarthPretrain_v1
3+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.olmoearth_pretrain_v1 import (
4+
OlmoEarthPretrain_v1,
5+
)
46

57
__all__ = ["OlmoEarthPretrain_v1"]

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/data/normalize.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import numpy as np
88

9-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import ModalitySpec
9+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
10+
ModalitySpec,
11+
)
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -18,7 +20,8 @@ def load_computed_config() -> dict[str, dict]:
1820
and std keys.
1921
"""
2022
with (
21-
files("olmoearth_pretrain_minimal.olmoearth_pretrain_v1.data.norm_configs") / "computed.json"
23+
files("olmoearth_pretrain_minimal.olmoearth_pretrain_v1.data.norm_configs")
24+
/ "computed.json"
2225
).open() as f:
2326
return json.load(f)
2427

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from einops import rearrange
1313
from torch import Tensor
1414

15-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import ModalitySpec
15+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
16+
ModalitySpec,
17+
)
1618

1719
logger = logging.getLogger(__name__)
1820

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,6 @@
1111
from torch import Tensor, nn
1212
from torch.distributed.fsdp import fully_shard
1313

14-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
15-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
16-
BASE_GSD,
17-
Modality,
18-
ModalitySpec,
19-
get_modality_specs_from_names,
20-
)
21-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample, MaskValue
2214
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.attention import Block
2315
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.encodings import (
2416
get_1d_sincos_pos_encoding,
@@ -29,8 +21,23 @@
2921
FlexiPatchEmbed,
3022
FlexiPatchReconstruction,
3123
)
32-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.tokenization import TokenizationConfig
33-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import get_cumulative_sequence_lengths
24+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.tokenization import (
25+
TokenizationConfig,
26+
)
27+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import (
28+
get_cumulative_sequence_lengths,
29+
)
30+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
31+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
32+
BASE_GSD,
33+
Modality,
34+
ModalitySpec,
35+
get_modality_specs_from_names,
36+
)
37+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import (
38+
MaskedOlmoEarthSample,
39+
MaskValue,
40+
)
3441

3542
logger = logging.getLogger(__name__)
3643

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@
1414
register_fsdp_forward_method,
1515
)
1616

17-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
18-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample
1917
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import TokensAndMasks
20-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import DistributedMixins, unpack_encoder_output
18+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import (
19+
DistributedMixins,
20+
unpack_encoder_output,
21+
)
22+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
23+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import (
24+
MaskedOlmoEarthSample,
25+
)
2126

2227
logger = logging.getLogger(__name__)
2328

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424

2525
from dataclasses import dataclass, field
2626

27-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality, ModalitySpec
27+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
28+
Modality,
29+
ModalitySpec,
30+
)
2831

2932

3033
@dataclass

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99

1010
import torch
1111

12+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import (
13+
EncoderConfig,
14+
PredictorConfig,
15+
)
16+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.latent_mim import (
17+
LatentMIMConfig,
18+
)
1219
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
13-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import EncoderConfig, PredictorConfig
14-
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.latent_mim import LatentMIM, LatentMIMConfig
1520

1621
# Model size configurations matching the official OlmoEarth v1 models
1722
MODEL_SIZE_CONFIGS = {

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,17 @@ def _resolve_class(cls, class_name: str) -> type | None:
5050
# Map old package paths to new ones for compatibility
5151
# Handle both "helios" (old name) and "olmoearth_pretrain" package names
5252
if class_name.startswith("helios."):
53-
class_name = class_name.replace("helios.", "olmoearth_pretrain_minimal.olmoearth_pretrain_v1.", 1)
53+
class_name = class_name.replace(
54+
"helios.", "olmoearth_pretrain_minimal.olmoearth_pretrain_v1.", 1
55+
)
5456
# Fix common typos in config files
5557
class_name = class_name.replace("flexihelios", "flexi_vit")
5658
elif class_name.startswith("olmoearth_pretrain."):
57-
class_name = class_name.replace("olmoearth_pretrain.", "olmoearth_pretrain_minimal.olmoearth_pretrain_v1.", 1)
59+
class_name = class_name.replace(
60+
"olmoearth_pretrain.",
61+
"olmoearth_pretrain_minimal.olmoearth_pretrain_v1.",
62+
1,
63+
)
5864

5965
*modules, cls_name = class_name.split(".")
6066
module_name = ".".join(modules)
@@ -95,10 +101,20 @@ def _clean_data(cls, data: Any) -> Any:
95101
# Try to resolve as Config using from_dict
96102
if cls.CLASS_NAME_FIELD in value:
97103
nested_class_name = value[cls.CLASS_NAME_FIELD]
98-
nested_resolved_cls = cls._resolve_class(nested_class_name)
99-
if nested_resolved_cls is not None and is_dataclass(nested_resolved_cls):
100-
nested_dict = {k: v for k, v in value.items() if k != cls.CLASS_NAME_FIELD}
101-
valid_kwargs[key] = cast("type[_StandaloneConfig]", nested_resolved_cls).from_dict(nested_dict)
104+
nested_resolved_cls = cls._resolve_class(
105+
nested_class_name
106+
)
107+
if nested_resolved_cls is not None and is_dataclass(
108+
nested_resolved_cls
109+
):
110+
nested_dict = {
111+
k: v
112+
for k, v in value.items()
113+
if k != cls.CLASS_NAME_FIELD
114+
}
115+
valid_kwargs[key] = cast(
116+
"type[_StandaloneConfig]", nested_resolved_cls
117+
).from_dict(nested_dict)
102118
else:
103119
raise ValueError(
104120
f"Could not resolve nested config class '{nested_class_name}' for field '{key}'"
@@ -158,8 +174,15 @@ def from_dict(
158174
class_name = cleaned[cls.CLASS_NAME_FIELD]
159175
resolved_cls = cls._resolve_class(class_name)
160176
if resolved_cls is not None and is_dataclass(resolved_cls):
161-
config_dict = {k: v for k, v in cleaned.items() if k != cls.CLASS_NAME_FIELD}
162-
return cast("type[_StandaloneConfig]", resolved_cls).from_dict(config_dict)
177+
config_dict = {
178+
k: v for k, v in cleaned.items() if k != cls.CLASS_NAME_FIELD
179+
}
180+
return cast(
181+
C,
182+
cast("type[_StandaloneConfig]", resolved_cls).from_dict(
183+
config_dict
184+
),
185+
)
163186
else:
164187
raise ValueError(
165188
f"Could not resolve class '{class_name}' from _CLASS_ field. "

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from enum import Enum
6-
from typing import TYPE_CHECKING, Any, NamedTuple
6+
from typing import Any, NamedTuple
77

88
import torch
99

0 commit comments

Comments
 (0)