Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions simplexity/persistence/local_pytorch_persister.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

from simplexity.persistence.local_persister import LocalPersister
from simplexity.persistence.utils import get_checkpoint_path

try:
import torch
Expand All @@ -20,7 +21,7 @@ def __init__(self, directory: str | Path, filename: str = "model.pt"):
# TODO: This is a hack to get the type checker to work.
def save_weights(self, model: torch.nn.Module, step: int = 0, overwrite_existing: bool = False) -> None: # type: ignore
"""Saves a PyTorch model to the local filesystem."""
path = self._get_path(step)
path = get_checkpoint_path(self.directory, step, self.filename)
path.parent.mkdir(parents=True, exist_ok=True)

if overwrite_existing and path.exists():
Expand All @@ -31,11 +32,8 @@ def save_weights(self, model: torch.nn.Module, step: int = 0, overwrite_existing
# TODO: This is a hack to get the type checker to work.
def load_weights(self, model: torch.nn.Module, step: int = 0) -> torch.nn.Module: # type: ignore
"""Loads weights into a PyTorch model from the local filesystem."""
path = self._get_path(step)
path = get_checkpoint_path(self.directory, step, self.filename)
device = next(model.parameters()).device if list(model.parameters()) else "cpu"
state_dict = torch.load(path, map_location=device)
model.load_state_dict(state_dict)
return model

def _get_path(self, step: int) -> Path:
return self.directory / str(step) / self.filename
69 changes: 69 additions & 0 deletions simplexity/persistence/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from pathlib import Path


def get_checkpoint_path(directory: Path, step: int, filename: str = "model.pt") -> Path:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should make use of format_step_number here so sorted checkpoint paths can naturally follow the ordering of the step number. Take an optional max_steps argument after filename and use format_step_number if the max_steps argument is supplied

"""Construct checkpoint path following the standard naming convention.

Args:
directory: Base directory for checkpoints
step: Training step number
filename: Checkpoint filename (default: "model.pt")

Returns:
Path to checkpoint file: {directory}/{step}/{filename}

Examples:
>>> get_checkpoint_path(Path("checkpoints"), 12345)
PosixPath('checkpoints/12345/model.pt')
>>> get_checkpoint_path(Path("weights"), 100, "state.pt")
PosixPath('weights/100/state.pt')
"""
return directory / str(step) / filename


def parse_checkpoint_step(path: str) -> int | None:
"""Extract training step number from checkpoint path.

Handles the format: {step}/model.pt or {step}/{filename}

Args:
path: File path or S3 key containing checkpoint

Returns:
Step number if found, None otherwise

Examples:
>>> parse_checkpoint_step("checkpoints/12345/model.pt")
12345
>>> parse_checkpoint_step("12345/model.pt")
12345
"""
parts = path.split("/")
if len(parts) >= 2 and parts[-1].endswith(".pt"):
try:
return int(parts[-2])
except ValueError:
pass

return None


def format_step_number(step: int, max_steps: int) -> str:
"""Format step number with appropriate zero-padding.

Args:
step: Current training step
max_steps: Maximum number of training steps

Returns:
Zero-padded step string

Examples:
>>> format_step_number(42, max_steps=100000)
'000042'
>>> format_step_number(999, max_steps=999)
'999'
"""
assert 0 <= step <= max_steps, f"Step {step} must be between 0 and {max_steps}"
width = len(str(max_steps))
return f"{step:0{width}d}"
73 changes: 73 additions & 0 deletions simplexity/utils/config_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: bool = False) -> int:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give use_bos a default value of False and make both booleans keyword-only arguments

"""Compute the generator's sequence length from model context length and special token usage.

The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS - EOS

Args:
model_n_ctx: The model's context length (number of input positions it processes)
use_bos: Whether a beginning-of-sequence token is prepended during data generation
use_eos: Whether an end-of-sequence token is appended during data generation

Returns:
The sequence length to configure for the data generator

Examples:
>>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False)
512
>>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=False)
513
>>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True)
511
"""
return model_n_ctx + 1 - int(use_bos) - int(use_eos)


def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: bool = False) -> int:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give use_bos a default value of False and make both booleans keyword-only arguments

"""Compute the model's context length from generator sequence length and special token usage.

The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

Args:
generator_seq_len: The sequence length configured for the data generator
use_bos: Whether a beginning-of-sequence token is prepended during data generation
use_eos: Whether an end-of-sequence token is appended during data generation

Returns:
The context length for the model (number of input positions it will process)

Examples:
>>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False)
512
>>> compute_model_context_length(generator_seq_len=513, use_bos=False, use_eos=False)
512
>>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True)
512
"""
return generator_seq_len - 1 + int(use_bos) + int(use_eos)


def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool) -> int:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we assume generator_vocab_size > 0, we should assert that

"""Compute the model's vocabulary size from generator vocab and special tokens.

When BOS or EOS tokens are used during data generation, they are added to the vocabulary,
increasing the total vocab size the model needs to handle.

Args:
generator_vocab_size: The vocabulary size of the data generator
use_bos: Whether a beginning-of-sequence token is used during data generation
use_eos: Whether an end-of-sequence token is used during data generation

Returns:
The vocabulary size the model should be configured with

Examples:
>>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False)
101
>>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True)
102
>>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False)
100
"""
return generator_vocab_size + int(use_bos) + int(use_eos)
39 changes: 39 additions & 0 deletions simplexity/utils/jnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,45 @@
import jax.numpy as jnp


def resolve_jax_device(device_spec: str | None = "auto") -> jax.Device:
"""Resolve device specification to actual JAX device.

Args:
device_spec: One of "auto", "gpu", "cuda", "cpu", or None (treated as "auto")

Returns:
JAX device object

Examples:
>>> resolve_jax_device("auto") # On GPU machine
GpuDevice(id=0, ...)
>>> resolve_jax_device("cpu")
CpuDevice(id=0)
"""
if device_spec is None or device_spec == "auto":
try:
devices = jax.devices("gpu")
if devices:
return devices[0]
except RuntimeError:
pass
return jax.devices("cpu")[0]

if device_spec in ("gpu", "cuda"):
try:
devices = jax.devices("gpu")
if devices:
return devices[0]
except RuntimeError:
pass
raise RuntimeError("GPU requested but no GPU devices available")

if device_spec == "cpu":
return jax.devices("cpu")[0]

raise ValueError(f"Unknown device specification: {device_spec}")


@eqx.filter_jit
def entropy(probs: jax.Array, log: bool = False) -> jax.Array:
"""Compute the entropy of a log probability distribution."""
Expand Down
43 changes: 43 additions & 0 deletions simplexity/utils/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,46 @@ def torch_to_jax(torch_tensor: torch.Tensor) -> jax.Array:
numpy_array = torch_tensor.detach().cpu().numpy()
jax_array = jnp.array(numpy_array)
return jax_array


def resolve_device(device_spec: str | None = "auto") -> str:
"""Resolve device specification to actual PyTorch device string.

Args:
device_spec: One of "auto", "cuda", "mps", "cpu", or None (treated as "auto")

Returns:
Resolved device string: "cuda", "mps", or "cpu"

Raises:
ValueError: If device_spec is not a recognized device type
RuntimeError: If a specific device is requested but unavailable

Examples:
>>> resolve_device("auto") # On CUDA machine
'cuda'
>>> resolve_device("cpu")
'cpu'
"""
if device_spec is None or device_spec == "auto":
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"

if device_spec == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("CUDA requested but CUDA is not available")
return "cuda"

if device_spec == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError("MPS requested but MPS is not available")
return "mps"

if device_spec == "cpu":
return "cpu"

raise ValueError(f"Unknown device specification: {device_spec}")
117 changes: 117 additions & 0 deletions tests/persistence/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from pathlib import Path

import pytest

from simplexity.persistence.utils import format_step_number, get_checkpoint_path, parse_checkpoint_step


class TestParseCheckpointStep:
"""Test parse_checkpoint_step function."""

@pytest.mark.parametrize(
("path", "expected"),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give some examples where the filename isn't model.pt

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, give examples with zero padding

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, try to keep the number of test cases to a minimum

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need 9 test cases, think about what important features each test case has and consolidate to a minimum set of test cases that covers all important features

[
("12345/model.pt", 12345),
("checkpoints/12345/model.pt", 12345),
("path/to/500/model.pt", 500),
("0/model.pt", 0),
("prefix/run_name/12345/model.pt", 12345),
],
)
def test_directory_model_format(self, path: str, expected: int):
"""Test parsing {step}/model.pt format."""
assert parse_checkpoint_step(path) == expected

@pytest.mark.parametrize(
"path",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give some examples where there is number in the path, but the filename is not valid

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, try to keep the number of test cases to a minimum

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need 7 test cases, think about what important features each test case has and consolidate to a minimum set of test cases that covers all important features

[
"model.pt",
"checkpoint.pt",
"weights/model.eqx",
"random_file.txt",
"nonumeric/model.pt",
],
)
def test_no_match_returns_none(self, path: str):
"""Test paths that should not match any pattern."""
assert parse_checkpoint_step(path) is None

def test_zero_padded_step_numbers(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be covered in previous test

"""Test that zero-padded step numbers are correctly parsed."""
assert parse_checkpoint_step("0000/model.pt") == 0


class TestGetCheckpointPath:
"""Test get_checkpoint_path function."""

def test_basic_path_construction(self):
"""Test basic checkpoint path construction."""
path = get_checkpoint_path(Path("checkpoints"), 12345)
assert path == Path("checkpoints/12345/model.pt")

def test_custom_filename(self):
"""Test with custom filename."""
path = get_checkpoint_path(Path("weights"), 100, "state.pt")
assert path == Path("weights/100/state.pt")

@pytest.mark.parametrize(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use an example filename with zero padding in the step number, and consolidate the total number of test cases

("directory", "step", "filename", "expected"),
[
(Path("checkpoints"), 0, "model.pt", Path("checkpoints/0/model.pt")),
(Path("runs/exp1"), 1000, "checkpoint.pt", Path("runs/exp1/1000/checkpoint.pt")),
(Path("."), 42, "model.pt", Path("42/model.pt")),
],
)
def test_parametrized_paths(self, directory: Path, step: int, filename: str, expected: Path):
"""Test various path combinations."""
assert get_checkpoint_path(directory, step, filename) == expected


class TestFormatStepNumber:
"""Test format_step_number function."""

def test_basic_formatting(self):
"""Test basic zero-padding behavior."""
assert format_step_number(42, max_steps=100) == "042"
assert format_step_number(5, max_steps=1000) == "0005"

def test_no_padding_needed(self):
"""Test when step already has maximum width."""
assert format_step_number(999, max_steps=999) == "999"
assert format_step_number(100, max_steps=100) == "100"

def test_zero_step(self):
"""Test formatting step 0."""
assert format_step_number(0, max_steps=100) == "000"
assert format_step_number(0, max_steps=10000) == "00000"

@pytest.mark.parametrize(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too many test cases, please consolidate

("step", "max_steps", "expected"),
[
(0, 999, "000"),
(1, 999, "001"),
(42, 999, "042"),
(999, 999, "999"),
(0, 100000, "000000"),
(42, 100000, "000042"),
(12345, 100000, "012345"),
(100000, 100000, "100000"),
],
)
def test_parametrized_formatting(self, step: int, max_steps: int, expected: str):
"""Test various step and max_steps combinations."""
assert format_step_number(step, max_steps) == expected

def test_lexicographic_ordering(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove

"""Verify that formatted strings sort lexicographically."""
max_steps = 10000
formatted = [format_step_number(i, max_steps) for i in [1, 10, 100, 1000, 9999]]
assert formatted == sorted(formatted)

def test_width_computation(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not useful, remove

"""Verify format_step_number computes width correctly."""
max_steps = 100000
step = 42
formatted = format_step_number(step, max_steps)
expected_width = len(str(max_steps))
assert len(formatted) == expected_width
Loading