3 changes: 1 addition & 2 deletions Makefile
@@ -9,8 +9,7 @@ TEST_IMAGE = $(shell beaker workspace images $(BEAKER_WORKSPACE) --format=json

.PHONY : run-checks
run-checks :
-isort --check .
-black --check .
+ruff format --check .
ruff check .
mypy .
CUDA_VISIBLE_DEVICES='' pytest -v --color=yes tests/
6 changes: 3 additions & 3 deletions evaluation/steps/run_catwalk.py
@@ -242,9 +242,9 @@ def _instance_predictions_map_list(
instance_id = guess_instance_id(instance, idx=idx) # dict

if keep_instance_fields or keep_all_instance_fields_except:
-assert (
-    keep_instance_fields is None or keep_all_instance_fields_except is None
-), "Can't use both keep_instance_fields and keep_all_instance_fields_except"
+assert keep_instance_fields is None or keep_all_instance_fields_except is None, (
+    "Can't use both keep_instance_fields and keep_all_instance_fields_except"
+)
for field in instance:
if keep_instance_fields and field not in keep_instance_fields:
continue
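The assert rewrites here, repeated through olmo/checkpoint.py, olmo/eval/downstream.py, and olmo/optim.py below, are the most common difference between the old black layout and ruff format's output in this diff: when a long assert carries a message, ruff keeps the condition on the assert line and parenthesizes only the message, while black had wrapped the condition. A small runnable sketch of the two equivalent layouts, using hypothetical values rather than code from the repo:

# Hypothetical values: at most one of the two keep-options may be set.
keep_instance_fields = ["id"]
keep_all_instance_fields_except = None

# Layout black produced: the condition is wrapped in parentheses.
assert (
    keep_instance_fields is None or keep_all_instance_fields_except is None
), "Can't use both keep_instance_fields and keep_all_instance_fields_except"

# Layout ruff format produces: the condition stays inline, the message is parenthesized.
assert keep_instance_fields is None or keep_all_instance_fields_except is None, (
    "Can't use both keep_instance_fields and keep_all_instance_fields_except"
)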
42 changes: 21 additions & 21 deletions olmo/checkpoint.py
@@ -471,9 +471,9 @@ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Fut
tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
target_tensor = planner.resolve_tensor(read_item).detach()

-assert (
-    target_tensor.size() == tensor.size()
-), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+assert target_tensor.size() == tensor.size(), (
+    f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+)
target_tensor.copy_(tensor)
planner.commit_tensor(read_item, target_tensor)

@@ -903,9 +903,9 @@ def save_checkpoint(
*,
upload_to: Optional[str] = None,
) -> None:
-assert isinstance(
-    dist_model, FSDP
-), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
+assert isinstance(dist_model, FSDP), (
+    f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
+)
with self._temporary_wd(dir) as checkpoint_dir:
# Save model and optim state.
save_fsdp_model_and_optim_state(
@@ -940,9 +940,9 @@ def restore_checkpoint(
) -> Dict[str, Any]:
# Load model and optimizer state in place.
log.info("Loading model and optimizer state...")
-assert isinstance(
-    dist_model, FSDP
-), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
+assert isinstance(dist_model, FSDP), (
+    f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
+)

load_fsdp_model_and_optim_state(
load_path,
@@ -987,9 +987,9 @@ def save_checkpoint(
*,
upload_to: Optional[str] = None,
) -> None:
-assert isinstance(
-    dist_model, FSDP
-), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
+assert isinstance(dist_model, FSDP), (
+    f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
+)
with self._temporary_wd(dir) as checkpoint_dir:
with FSDP.state_dict_type(
dist_model,
@@ -1022,9 +1022,9 @@ def restore_checkpoint(
local_cache: Optional[PathOrStr] = None,
load_optimizer_state: bool = True,
) -> Dict[str, Any]:
-assert isinstance(
-    dist_model, FSDP
-), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
+assert isinstance(dist_model, FSDP), (
+    f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
+)
with FSDP.state_dict_type(
dist_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
@@ -1587,9 +1587,9 @@ def save_checkpoint(
*,
upload_to: Optional[str] = None,
) -> None:
-assert isinstance(
-    dist_model, FSDP
-), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
+assert isinstance(dist_model, FSDP), (
+    f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
+)

with self._temporary_wd(dir) as checkpoint_dir:
# Gather local FSDP flat params data to save.
@@ -1648,9 +1648,9 @@ def restore_checkpoint(

# Load local FSDP flat param data.
log.info("Loading local FSDP flat params data...")
-assert isinstance(
-    dist_model, FSDP
-), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
+assert isinstance(dist_model, FSDP), (
+    f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
+)

model_state = load_state_dict(
load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
12 changes: 6 additions & 6 deletions olmo/config.py
@@ -650,13 +650,13 @@ class CustomDatasetCollatorConfig(BaseConfig):
@dataclass
class CustomDatasetConfig(BaseConfig):
name: str #: The name of the custom dataset class or function that will be used to load the dataset.
-module: Optional[
-    str
-] = None #: The module where the custom dataset class is defined. If not set, the module will be inferred from the class name.
+module: Optional[str] = (
+    None #: The module where the custom dataset class is defined. If not set, the module will be inferred from the class name.
+)
args: Optional[Dict[str, Any]] = None #: The arguments to pass to the custom dataset class or function
-collate_fn: Optional[
-    str
-] = None #: The name of the collate function to use for the custom dataset. Assumes the collate function is defined in the same module as the custom dataset class unless specified otherwise using the full object path.
+collate_fn: Optional[str] = (
+    None #: The name of the collate function to use for the custom dataset. Assumes the collate function is defined in the same module as the custom dataset class unless specified otherwise using the full object path.
+)
token_field: Optional[str] = None #: The field in the dataset items that contains the tokenized text.
collate_config: Optional[CustomDatasetCollatorConfig] = field(
default_factory=CustomDatasetCollatorConfig
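The olmo/config.py hunks show another formatter difference: for an annotated assignment whose trailing comment pushes the line past the 115-character limit, black split inside the subscripted annotation, while ruff format keeps the annotation on one line and parenthesizes the value so the comment rides inside the parentheses. A stand-alone sketch with a made-up field name (not a field from the repo):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleConfig:
    # Old black layout (shown as a comment): the annotation itself gets split.
    #   some_field: Optional[
    #       str
    #   ] = None  #: long trailing comment...
    #
    # ruff format's layout: annotation intact, value parenthesized.
    some_field: Optional[str] = (
        None  #: A long trailing comment that pushes this assignment past the line-length limit.
    )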
6 changes: 3 additions & 3 deletions olmo/eval/downstream.py
@@ -1528,9 +1528,9 @@ def prep_examples(self):
continue
if doc_id > max_doc_id:
max_doc_id = doc_id
-assert (
-    request["request_type"] == "loglikelihood"
-), f"Unsupported request type: {request['request_type']}"
+assert request["request_type"] == "loglikelihood", (
+    f"Unsupported request type: {request['request_type']}"
+)

# from EAI harness
# how this all works:
13 changes: 7 additions & 6 deletions olmo/optim.py
@@ -393,9 +393,9 @@ def __init__(
def get_post_step_metrics(
self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
) -> Dict[str, torch.Tensor]:
-assert isinstance(
-    module, FSDP
-), "`get_post_step_metrics` expects module to be FSDP and will not work with other `distributed_strategy`."
+assert isinstance(module, FSDP), (
+    "`get_post_step_metrics` expects module to be FSDP and will not work with other `distributed_strategy`."
+)

update_total_dot_prod = self._update_total_dot_prod
update_total_norm = self._update_total_norm
@@ -792,6 +792,7 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
@dataclass
class CosLinearEnvelope(Scheduler):
"Pointwise product of cosine schedule and linear decay; useful during annealing."
+
warmup_steps: int
alpha_f: float = 0.1
t_max: Optional[int] = None
@@ -874,9 +875,9 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
-assert (
-    len(all_params.keys() - union_params) == 0
-), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
+assert len(all_params.keys() - union_params) == 0, (
+    f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
+)

# Create the pytorch optimizer groups.
decay_sorted = sorted(list(decay))
8 changes: 4 additions & 4 deletions olmo/util.py
@@ -512,7 +512,7 @@ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
or (step == latest_step and latest_checkpoint is not None and latest_checkpoint.endswith("-unsharded"))
):
latest_step = step
-latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}"
+latest_checkpoint = f"gs://{bucket_name}/{blob.name[: -len(suffix)]}"

return latest_checkpoint

@@ -710,7 +710,7 @@ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: i
try:
response = requests.get(
f"{scheme}://{host_name}/{path}",
-headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"},
+headers={"Range": f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"},
)
result = response.content
if len(result) == num_bytes:
@@ -719,7 +719,7 @@ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: i
log.warning(f"Expected {num_bytes} bytes, but got {len(result)}. Retrying...")

except requests.exceptions.RequestException as e:
-log.warning(f"Attempt {attempt+1}/{max_retries}. Network error: {e}. Retrying...")
+log.warning(f"Attempt {attempt + 1}/{max_retries}. Network error: {e}. Retrying...")
attempt += 1
time.sleep(2**attempt)
raise ValueError(
@@ -910,7 +910,7 @@ def get_resource(self, temp_file: io.BufferedWriter) -> None:

def get_bytes_range(self, index: int, length: int) -> bytes:
response = self.s3.get_object(
-Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}"
+Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index + length - 1}"
)
return response["Body"].read()

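All four olmo/util.py hunks are whitespace-only: ruff format applies normal operator spacing inside f-string replacement fields (index + length - 1) and PEP 8 slice spacing when a bound is an expression ([: -len(suffix)]), neither of which the previously black-formatted code had. A tiny runnable sketch with hypothetical values showing the spacing it settles on:

# Hypothetical values, purely to illustrate the spacing.
index, length = 100, 50
bytes_start, num_bytes = 0, 1024
name, suffix = "step1000-unsharded/", "/"

range_header = f"bytes={index}-{index + length - 1}"  # spaces around + and - inside the f-string
http_range = f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
trimmed = name[: -len(suffix)]  # space after ':' because the bound is an expression

print(range_header, http_range, trimmed)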
28 changes: 3 additions & 25 deletions pyproject.toml
@@ -31,8 +31,6 @@ dependencies = [
dev = [
"ruff",
"mypy>=1.0,<1.4",
-"black>=23.1,<24.0",
-"isort>=5.12,<5.13",
"pytest",
"pytest-sphinx",
"twine>=1.11.0",
@@ -90,29 +88,6 @@ exclude = [
"inference.*",
]

-[tool.black]
-line-length = 115
-include = '\.pyi?$'
-exclude = '''
-(
-__pycache__
-| \.git
-| \.mypy_cache
-| \.pytest_cache
-| \.vscode
-| \.venv
-| \bdist\b
-| \bdoc\b
-| pretrain_data/
-| inference/
-)
-'''
-
-[tool.isort]
-profile = "black"
-multi_line_output = 3
-extend_skip = ["pretrain_data", "tokenizer"]
-
[tool.ruff]
line-length = 115
lint.ignore = ["F403", "F405", "E501"]
@@ -145,6 +120,9 @@ exclude = [
[tool.ruff.lint.per-file-ignores]
"**/__init__.py" = ["F401"]

+[tool.ruff.format]
+# ruff format is black-compatible by default
+
[tool.pyright]
reportPrivateImportUsage = false
exclude = ["pretrain_data/", "tokenizer/"]
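With the [tool.black] and [tool.isort] tables gone, ruff is the only formatter and linter configured here; the existing [tool.ruff] table already pins line-length = 115, so the new [tool.ruff.format] table can stay empty. As a hedged sketch only (it assumes ruff and mypy are installed from the dev extras above and uses their standard CLI flags; pytest is omitted), the updated make run-checks target could be reproduced from Python like this:

import subprocess
import sys

# Mirrors the new run-checks target: format check, lint, then type check.
commands = [
    ["ruff", "format", "--check", "."],
    ["ruff", "check", "."],
    ["mypy", "."],
]

for cmd in commands:
    result = subprocess.run(cmd)
    if result.returncode != 0:
        sys.exit(result.returncode)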
1 change: 1 addition & 0 deletions scripts/add_code_eval.py
@@ -1,6 +1,7 @@
"""
Script to create perplexity eval datasets for code.
"""
+
import os

import pandas as pd
2 changes: 1 addition & 1 deletion scripts/compare_module_outputs.py
@@ -4,7 +4,7 @@

This script is useful for identifying where model activations start to differ
within 2 forward passes that should yield identical results. In turn, detecting
-regressions can be a lot quicker/easier. 
+regressions can be a lot quicker/easier.

This script requires that traces containing submodule outputs have been collected
during training. The traces can be saved using
2 changes: 1 addition & 1 deletion scripts/compare_wandb_configs.py
@@ -6,7 +6,7 @@

Comparing Peteish7 to Amberish7
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmo-medium/runs/ij4ls6v2


"""

2 changes: 1 addition & 1 deletion scripts/flops_by_perf_figure.py
@@ -23,7 +23,7 @@
Zamba-2-7B,,65.2,92.2,89.4,79.6,68.5,51.7,36.5,55.5,67.2,32.8,78.8

Invocation looks like:

python scripts/flops_by_perf_figure.py /path/to/results.csv output/

@kyleclo, @soldni
1 change: 1 addition & 0 deletions scripts/init_config.py
@@ -1,6 +1,7 @@
"""
Run this to initialize a new training config to a file.
"""
+
import logging
import sys
from pathlib import Path
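The two script changes, like the CosLinearEnvelope change in olmo/optim.py above, come with the switch to ruff format: it separates a module or class docstring from the code that follows it with one blank line. A trimmed-down, runnable sketch (the class is a stand-in, not the real Scheduler subclass):

"""
Module docstring: one blank line follows before the first import.
"""

from dataclasses import dataclass


@dataclass
class CosLinearEnvelopeSketch:
    "Pointwise product of cosine schedule and linear decay; useful during annealing."

    warmup_steps: int = 0  # the blank line above is the same rule applied to the class docstring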