-
Notifications
You must be signed in to change notification settings - Fork 2
Add configuration resolution and utility functions #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
c744d70
b93d136
ad18345
137e2f7
d9f9283
bf2a3a6
5aa7887
dd12490
4ede1d1
886b754
b63dbc5
2844fa6
bb00983
4416f0d
5b210c7
85691a1
12f52f7
ab0da89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| from pathlib import Path | ||
|
|
||
|
|
||
| def get_checkpoint_path(directory: Path, step: int, filename: str = "model.pt") -> Path: | ||
| """Construct checkpoint path following the standard naming convention. | ||
|
|
||
| Args: | ||
| directory: Base directory for checkpoints | ||
| step: Training step number (must be non-negative) | ||
| filename: Checkpoint filename (default: "model.pt") | ||
|
|
||
| Returns: | ||
| Path to checkpoint file: {directory}/{step}/{filename} | ||
|
|
||
| Raises: | ||
| ValueError: If step is negative | ||
|
|
||
| Examples: | ||
| >>> get_checkpoint_path(Path("checkpoints"), 12345) | ||
| PosixPath('checkpoints/12345/model.pt') | ||
| >>> get_checkpoint_path(Path("weights"), 100, "state.pt") | ||
| PosixPath('weights/100/state.pt') | ||
| """ | ||
| if step < 0: | ||
| raise ValueError(f"Step must be non-negative, got {step}") | ||
| return directory / str(step) / filename | ||
|
|
||
|
|
||
| def parse_checkpoint_step(path: str) -> int | None: | ||
ealt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Extract training step number from checkpoint path. | ||
|
|
||
| Handles the format: {step}/model.pt or {step}/{filename} | ||
|
|
||
| Args: | ||
| path: File path or S3 key containing checkpoint | ||
|
|
||
| Returns: | ||
| Step number if found, None otherwise | ||
|
|
||
| Examples: | ||
| >>> parse_checkpoint_step("checkpoints/12345/model.pt") | ||
| 12345 | ||
| >>> parse_checkpoint_step("12345/model.pt") | ||
| 12345 | ||
| """ | ||
| parts = path.split("/") | ||
| if len(parts) >= 2 and parts[-1].endswith((".pt", ".eqx")): | ||
|
||
| try: | ||
| return int(parts[-2]) | ||
| except ValueError: | ||
| pass | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| def format_step_number(step: int, max_steps: int) -> str: | ||
ealt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Format step number with appropriate zero-padding. | ||
|
|
||
| Args: | ||
| step: Current training step | ||
| max_steps: Maximum number of training steps | ||
|
|
||
| Returns: | ||
| Zero-padded step string | ||
|
|
||
| Raises: | ||
| ValueError: If step is not between 0 and max_steps | ||
|
|
||
| Examples: | ||
| >>> format_step_number(42, max_steps=100000) | ||
| '000042' | ||
| >>> format_step_number(999, max_steps=999) | ||
| '999' | ||
| """ | ||
| if not 0 <= step <= max_steps: | ||
| raise ValueError(f"Step {step} must be between 0 and {max_steps}") | ||
| width = len(str(max_steps)) | ||
| return f"{step:0{width}d}" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: bool = False) -> int: | ||
ealt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Compute the generator's sequence length from model context length and special token usage. | ||
|
|
||
| The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS | ||
|
|
||
| Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS - EOS | ||
|
|
||
| Args: | ||
| model_n_ctx: The model's context length (number of input positions it processes) | ||
| use_bos: Whether a beginning-of-sequence token is prepended during data generation | ||
| use_eos: Whether an end-of-sequence token is appended during data generation | ||
|
|
||
| Returns: | ||
| The sequence length to configure for the data generator | ||
|
|
||
| Raises: | ||
| ValueError: If the resulting generator sequence length would be non-positive | ||
|
|
||
| Examples: | ||
| >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False) | ||
| 512 | ||
| >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=False) | ||
| 513 | ||
| >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True) | ||
| 511 | ||
| """ | ||
| assert model_n_ctx > 0, f"model_n_ctx must be positive, got {model_n_ctx}" | ||
|
|
||
| result = model_n_ctx + 1 - int(use_bos) - int(use_eos) | ||
| if result <= 0: | ||
| raise ValueError( | ||
| f"Invalid configuration: model_n_ctx={model_n_ctx}, use_bos={use_bos}, use_eos={use_eos} " | ||
| f"results in non-positive generator sequence length ({result})" | ||
| ) | ||
| return result | ||
|
|
||
|
|
||
| def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: bool = False) -> int: | ||
ealt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Compute the model's context length from generator sequence length and special token usage. | ||
|
|
||
| The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS | ||
|
|
||
| Args: | ||
| generator_seq_len: The sequence length configured for the data generator | ||
| use_bos: Whether a beginning-of-sequence token is prepended during data generation | ||
| use_eos: Whether an end-of-sequence token is appended during data generation | ||
|
|
||
| Returns: | ||
| The context length for the model (number of input positions it will process) | ||
|
|
||
| Raises: | ||
| ValueError: If the resulting model context length would be non-positive | ||
|
|
||
| Examples: | ||
| >>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False) | ||
| 512 | ||
| >>> compute_model_context_length(generator_seq_len=513, use_bos=False, use_eos=False) | ||
| 512 | ||
| >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True) | ||
| 512 | ||
| """ | ||
| assert generator_seq_len > 0, f"generator_seq_len must be positive, got {generator_seq_len}" | ||
|
|
||
| result = generator_seq_len - 1 + int(use_bos) + int(use_eos) | ||
| if result <= 0: | ||
| raise ValueError( | ||
| f"Invalid configuration: generator_seq_len={generator_seq_len}, use_bos={use_bos}, use_eos={use_eos} " | ||
| f"results in non-positive model context length ({result})" | ||
| ) | ||
| return result | ||
|
|
||
|
|
||
| def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool = False) -> int: | ||
|
||
| """Compute the model's vocabulary size from generator vocab and special tokens. | ||
|
|
||
| When BOS or EOS tokens are used during data generation, they are added to the vocabulary, | ||
| increasing the total vocab size the model needs to handle. | ||
|
|
||
| Args: | ||
| generator_vocab_size: The vocabulary size of the data generator | ||
| use_bos: Whether a beginning-of-sequence token is used during data generation | ||
| use_eos: Whether an end-of-sequence token is used during data generation | ||
|
|
||
| Returns: | ||
| The vocabulary size the model should be configured with | ||
|
|
||
| Raises: | ||
| ValueError: If generator_vocab_size is non-positive | ||
|
|
||
| Examples: | ||
| >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False) | ||
| 101 | ||
| >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True) | ||
| 102 | ||
| >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False) | ||
| 100 | ||
| """ | ||
| assert generator_vocab_size > 0, f"generator_vocab_size must be positive, got {generator_vocab_size}" | ||
| return generator_vocab_size + int(use_bos) + int(use_eos) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should make use of
format_step_numberhere so sorted checkpoint paths can naturally follow the ordering of the step number. Take an optionalmax_stepsargument after filename and useformat_step_numberif themax_stepsargument is supplied