Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions simplexity/persistence/local_pytorch_persister.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

from simplexity.persistence.local_persister import LocalPersister
from simplexity.persistence.utils import get_checkpoint_path

try:
import torch
Expand All @@ -20,7 +21,7 @@ def __init__(self, directory: str | Path, filename: str = "model.pt"):
# TODO: This is a hack to get the type checker to work.
def save_weights(self, model: torch.nn.Module, step: int = 0, overwrite_existing: bool = False) -> None: # type: ignore
"""Saves a PyTorch model to the local filesystem."""
path = self._get_path(step)
path = get_checkpoint_path(self.directory, step, self.filename)
path.parent.mkdir(parents=True, exist_ok=True)

if overwrite_existing and path.exists():
Expand All @@ -31,11 +32,8 @@ def save_weights(self, model: torch.nn.Module, step: int = 0, overwrite_existing
# TODO: This is a hack to get the type checker to work.
def load_weights(self, model: torch.nn.Module, step: int = 0) -> torch.nn.Module:  # type: ignore
    """Loads weights into a PyTorch model from the local filesystem.

    Resolves the checkpoint path from the persister's directory/filename and
    the given step, then loads the state dict into `model` in place.
    """
    # Shared path convention: {directory}/{step}/{filename} (see persistence.utils).
    path = get_checkpoint_path(self.directory, step, self.filename)
    # Map tensors onto the model's current device; parameterless models fall back to CPU.
    device = next(model.parameters()).device if list(model.parameters()) else "cpu"
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    return model

def _get_path(self, step: int) -> Path:
return self.directory / str(step) / self.filename
113 changes: 113 additions & 0 deletions simplexity/persistence/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from pathlib import Path

SUPPORTED_EXTENSIONS = (".pt", ".eqx", ".pkl", ".ckpt", ".pth")


def _is_valid_checkpoint_filename(filename: str) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Among the set of all strings, there are more constrains on a valid filename than simply ending with an acceptable extension

"""Check if filename is a valid checkpoint filename with supported extension.

Args:
filename: The checkpoint filename to validate

Returns:
True if filename has a supported extension, False otherwise

Examples:
>>> _is_valid_checkpoint_filename("model.pt")
True
>>> _is_valid_checkpoint_filename("state.eqx")
True
>>> _is_valid_checkpoint_filename("invalid.txt")
False
"""
return filename.endswith(SUPPORTED_EXTENSIONS)


def get_checkpoint_path(
    directory: Path, step: int, filename: str = "model.pt", max_steps: int | None = None
) -> Path:
    """Construct checkpoint path following the standard naming convention.

    Args:
        directory: Base directory for checkpoints
        step: Training step number (must be non-negative)
        filename: Checkpoint filename (default: "model.pt")
        max_steps: Maximum number of training steps. If provided, step will be zero-padded

    Returns:
        Path to checkpoint file: {directory}/{step}/{filename}

    Raises:
        ValueError: If step is negative or filename has unsupported extension

    Examples:
        >>> get_checkpoint_path(Path("checkpoints"), 12345)
        PosixPath('checkpoints/12345/model.pt')
        >>> get_checkpoint_path(Path("weights"), 100, "state.pt")
        PosixPath('weights/100/state.pt')
        >>> get_checkpoint_path(Path("checkpoints"), 42, max_steps=100000)
        PosixPath('checkpoints/000042/model.pt')
    """
    if step < 0:
        raise ValueError(f"Step must be non-negative, got {step}")
    if not _is_valid_checkpoint_filename(filename):
        # NOTE: the error message interpolates the offending filename for debuggability.
        raise ValueError(f"Filename must have one of these extensions: {SUPPORTED_EXTENSIONS}, got {filename}")

    # Zero-pad the step directory only when the target width (max_steps) is known,
    # so paths sort lexicographically in step order.
    if max_steps is not None:
        step_str = format_step_number(step, max_steps)
    else:
        step_str = str(step)

    return directory / step_str / filename


def parse_checkpoint_step(path: str) -> int | None:
    """Extract the training step number from a checkpoint path.

    The step is expected to be the directory component immediately preceding
    the checkpoint filename, i.e. paths of the form "{step}/{filename}".

    Args:
        path: File path or S3 key containing checkpoint

    Returns:
        Step number if found, None otherwise

    Examples:
        >>> parse_checkpoint_step("checkpoints/12345/model.pt")
        12345
        >>> parse_checkpoint_step("12345/model.pt")
        12345
    """
    segments = path.split("/")
    # Need at least a step directory plus a filename, and the filename must
    # look like a checkpoint file.
    if len(segments) < 2 or not _is_valid_checkpoint_filename(segments[-1]):
        return None
    try:
        return int(segments[-2])
    except ValueError:
        # Parent directory is not a plain integer step.
        return None


def format_step_number(step: int, max_steps: int) -> str:
    """Format step number with appropriate zero-padding.

    The padding width is derived from the decimal width of max_steps so that
    all step strings for a run sort lexicographically.

    Args:
        step: Current training step
        max_steps: Maximum number of training steps

    Returns:
        Zero-padded step string

    Raises:
        ValueError: If step is not between 0 and max_steps

    Examples:
        >>> format_step_number(42, max_steps=100000)
        '000042'
        >>> format_step_number(999, max_steps=999)
        '999'
    """
    if step < 0 or step > max_steps:
        raise ValueError(f"Step {step} must be between 0 and {max_steps}")
    return str(step).zfill(len(str(max_steps)))
99 changes: 99 additions & 0 deletions simplexity/utils/config_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
def compute_generator_sequence_length(model_n_ctx: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    """Compute the generator's sequence length from model context length and special token usage.

    The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

    Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS - EOS

    Args:
        model_n_ctx: The model's context length (number of input positions it processes)
        use_bos: Whether a beginning-of-sequence token is prepended during data generation
        use_eos: Whether an end-of-sequence token is appended during data generation

    Returns:
        The sequence length to configure for the data generator

    Raises:
        ValueError: If model_n_ctx is non-positive, or the resulting generator
            sequence length would be non-positive

    Examples:
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False)
        512
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=False)
        513
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True)
        511
    """
    # Validate with an exception, not `assert`: asserts are stripped under `python -O`,
    # and the docstring promises ValueError for invalid configurations.
    if model_n_ctx <= 0:
        raise ValueError(f"model_n_ctx must be positive, got {model_n_ctx}")

    result = model_n_ctx + 1 - int(use_bos) - int(use_eos)
    if result <= 0:
        raise ValueError(
            f"Invalid configuration: model_n_ctx={model_n_ctx}, use_bos={use_bos}, use_eos={use_eos} "
            f"results in non-positive generator sequence length ({result})"
        )
    return result


def compute_model_context_length(generator_seq_len: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    """Compute the model's context length from generator sequence length and special token usage.

    The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

    Args:
        generator_seq_len: The sequence length configured for the data generator
        use_bos: Whether a beginning-of-sequence token is prepended during data generation
        use_eos: Whether an end-of-sequence token is appended during data generation

    Returns:
        The context length for the model (number of input positions it will process)

    Raises:
        ValueError: If generator_seq_len is non-positive, or the resulting model
            context length would be non-positive

    Examples:
        >>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False)
        512
        >>> compute_model_context_length(generator_seq_len=513, use_bos=False, use_eos=False)
        512
        >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True)
        512
    """
    # Validate with an exception, not `assert`: asserts are stripped under `python -O`,
    # and the docstring promises ValueError for invalid configurations.
    if generator_seq_len <= 0:
        raise ValueError(f"generator_seq_len must be positive, got {generator_seq_len}")

    result = generator_seq_len - 1 + int(use_bos) + int(use_eos)
    if result <= 0:
        raise ValueError(
            f"Invalid configuration: generator_seq_len={generator_seq_len}, use_bos={use_bos}, use_eos={use_eos} "
            f"results in non-positive model context length ({result})"
        )
    return result


def compute_model_vocab_size(generator_vocab_size: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    """Compute the model's vocabulary size from generator vocab and special tokens.

    When BOS or EOS tokens are used during data generation, they are added to the vocabulary,
    increasing the total vocab size the model needs to handle.

    Args:
        generator_vocab_size: The vocabulary size of the data generator
        use_bos: Whether a beginning-of-sequence token is used during data generation
        use_eos: Whether an end-of-sequence token is used during data generation

    Returns:
        The vocabulary size the model should be configured with

    Raises:
        ValueError: If generator_vocab_size is non-positive

    Examples:
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False)
        101
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True)
        102
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False)
        100
    """
    # Raise ValueError as documented; the previous `assert` would be stripped
    # under `python -O` and raised the wrong exception type.
    if generator_vocab_size <= 0:
        raise ValueError(f"generator_vocab_size must be positive, got {generator_vocab_size}")
    return generator_vocab_size + int(use_bos) + int(use_eos)
39 changes: 39 additions & 0 deletions simplexity/utils/jnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,45 @@
import jax.numpy as jnp


def resolve_jax_device(device_spec: str | None = "auto") -> jax.Device:
"""Resolve device specification to actual JAX device.

Args:
device_spec: One of "auto", "gpu", "cuda", "cpu", or None (treated as "auto")

Returns:
JAX device object

Examples:
>>> resolve_jax_device("auto") # On GPU machine
GpuDevice(id=0, ...)
>>> resolve_jax_device("cpu")
CpuDevice(id=0)
"""
if device_spec is None or device_spec == "auto":
try:
devices = jax.devices("gpu")
if devices:
return devices[0]
except RuntimeError:
pass
return jax.devices("cpu")[0]

if device_spec in ("gpu", "cuda"):
try:
devices = jax.devices("gpu")
if devices:
return devices[0]
except RuntimeError:
pass
raise RuntimeError("GPU requested but no GPU devices available")

if device_spec == "cpu":
return jax.devices("cpu")[0]

raise ValueError(f"Unknown device specification: {device_spec}")


@eqx.filter_jit
def entropy(probs: jax.Array, log: bool = False) -> jax.Array:
"""Compute the entropy of a log probability distribution."""
Expand Down
43 changes: 43 additions & 0 deletions simplexity/utils/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,46 @@ def torch_to_jax(torch_tensor: torch.Tensor) -> jax.Array:
numpy_array = torch_tensor.detach().cpu().numpy()
jax_array = jnp.array(numpy_array)
return jax_array


def resolve_device(device_spec: str | None = "auto") -> str:
"""Resolve device specification to actual PyTorch device string.

Args:
device_spec: One of "auto", "cuda", "mps", "cpu", or None (treated as "auto")

Returns:
Resolved device string: "cuda", "mps", or "cpu"

Raises:
ValueError: If device_spec is not a recognized device type
RuntimeError: If a specific device is requested but unavailable

Examples:
>>> resolve_device("auto") # On CUDA machine
'cuda'
>>> resolve_device("cpu")
'cpu'
"""
if device_spec is None or device_spec == "auto":
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"

if device_spec == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("CUDA requested but CUDA is not available")
return "cuda"

if device_spec == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError("MPS requested but MPS is not available")
return "mps"

if device_spec == "cpu":
return "cpu"

raise ValueError(f"Unknown device specification: {device_spec}")
Loading
Loading