-
Notifications
You must be signed in to change notification settings - Fork 2
Add configuration resolution and utility functions #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
c744d70
b93d136
ad18345
137e2f7
d9f9283
bf2a3a6
5aa7887
dd12490
4ede1d1
886b754
b63dbc5
2844fa6
bb00983
4416f0d
5b210c7
85691a1
12f52f7
ab0da89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| from pathlib import Path | ||
|
|
||
|
|
||
def get_checkpoint_path(
    directory: Path,
    step: int,
    filename: str = "model.pt",
    max_steps: int | None = None,
) -> Path:
    """Construct checkpoint path following the standard naming convention.

    Args:
        directory: Base directory for checkpoints
        step: Training step number
        filename: Checkpoint filename (default: "model.pt")
        max_steps: Optional maximum number of training steps. When supplied,
            the step directory name is zero-padded via ``format_step_number``
            so that lexicographically sorted checkpoint paths follow step
            order. When omitted, the step is written without padding.

    Returns:
        Path to checkpoint file: {directory}/{step}/{filename}, where {step}
        is zero-padded to the width of ``max_steps`` if it was given.

    Examples:
        >>> get_checkpoint_path(Path("checkpoints"), 12345)
        PosixPath('checkpoints/12345/model.pt')
        >>> get_checkpoint_path(Path("weights"), 100, "state.pt")
        PosixPath('weights/100/state.pt')
        >>> get_checkpoint_path(Path("checkpoints"), 42, max_steps=100000)
        PosixPath('checkpoints/000042/model.pt')
    """
    if max_steps is None:
        step_dir = str(step)
    else:
        # Delegate padding to the shared helper so both functions agree on width.
        step_dir = format_step_number(step, max_steps)
    return directory / step_dir / filename
|
|
||
|
|
||
def parse_checkpoint_step(path: str) -> int | None:
    """Extract training step number from checkpoint path.

    Handles the format: {step}/model.pt or {step}/{filename}, where the
    filename ends in ".pt" and {step} is the directory immediately above it.
    Zero-padded step directories (e.g. "000042") are accepted.

    Args:
        path: File path or S3 key containing checkpoint ("/"-separated)

    Returns:
        Step number if found, None otherwise

    Examples:
        >>> parse_checkpoint_step("checkpoints/12345/model.pt")
        12345
        >>> parse_checkpoint_step("12345/model.pt")
        12345
        >>> parse_checkpoint_step("000042/model.pt")
        42
        >>> parse_checkpoint_step("-5/model.pt") is None
        True
    """
    parts = path.split("/")
    if len(parts) < 2 or not parts[-1].endswith(".pt"):
        return None

    step_dir = parts[-2]
    # Require plain ASCII digits: int() alone would also accept signs ("-5"),
    # surrounding whitespace (" 7 "), underscores ("1_000") and non-ASCII
    # digits, none of which are valid step directory names.
    if not (step_dir.isascii() and step_dir.isdigit()):
        return None
    return int(step_dir)
|
|
||
|
|
||
def format_step_number(step: int, max_steps: int) -> str:
    """Format step number with appropriate zero-padding.

    The padding width is the number of decimal digits in ``max_steps``, so
    every step in ``[0, max_steps]`` formats to the same length.

    Args:
        step: Current training step
        max_steps: Maximum number of training steps

    Returns:
        Zero-padded step string

    Raises:
        ValueError: If ``step`` is not in the range ``[0, max_steps]``.

    Examples:
        >>> format_step_number(42, max_steps=100000)
        '000042'
        >>> format_step_number(999, max_steps=999)
        '999'
    """
    # Explicit raise rather than `assert`: assertions are stripped under
    # `python -O`, which would let out-of-range steps through unpadded.
    if not 0 <= step <= max_steps:
        raise ValueError(f"Step {step} must be between 0 and {max_steps}")
    width = len(str(max_steps))
    return f"{step:0{width}d}"
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: bool = False) -> int:
    """Compute the generator's sequence length from model context length and special token usage.

    The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

    Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS - EOS

    Here BOS and EOS are 1 when the corresponding flag is set and 0 otherwise.

    Args:
        model_n_ctx: The model's context length (number of input positions it processes)
        use_bos: Whether a beginning-of-sequence token is prepended during data generation
        use_eos: Whether an end-of-sequence token is appended during data generation

    Returns:
        The sequence length to configure for the data generator

    Examples:
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False)
        512
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=False)
        513
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True)
        511
    """
    return model_n_ctx + 1 - int(use_bos) - int(use_eos)
|
|
||
|
|
||
def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: bool = False) -> int:
    """Compute the model's context length from generator sequence length and special token usage.

    The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

    Here BOS and EOS are 1 when the corresponding flag is set and 0 otherwise.

    Args:
        generator_seq_len: The sequence length configured for the data generator
        use_bos: Whether a beginning-of-sequence token is prepended during data generation
        use_eos: Whether an end-of-sequence token is appended during data generation

    Returns:
        The context length for the model (number of input positions it will process)

    Examples:
        >>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False)
        512
        >>> compute_model_context_length(generator_seq_len=513, use_bos=False, use_eos=False)
        512
        >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True)
        512
    """
    return generator_seq_len - 1 + int(use_bos) + int(use_eos)
|
|
||
|
|
||
def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool = False) -> int:
    """Compute the model's vocabulary size from generator vocab and special tokens.

    When BOS or EOS tokens are used during data generation, they are added to the vocabulary,
    increasing the total vocab size the model needs to handle.

    Args:
        generator_vocab_size: The vocabulary size of the data generator
        use_bos: Whether a beginning-of-sequence token is used during data generation
        use_eos: Whether an end-of-sequence token is used during data generation
            (default: False, matching the sibling sequence-length helpers)

    Returns:
        The vocabulary size the model should be configured with

    Examples:
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False)
        101
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True)
        102
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False)
        100
    """
    return generator_vocab_size + int(use_bos) + int(use_eos)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
| from simplexity.persistence.utils import format_step_number, get_checkpoint_path, parse_checkpoint_step | ||
|
|
||
|
|
||
class TestParseCheckpointStep:
    """Test parse_checkpoint_step function."""

    # Review notes (Collaborator) on the parameter list below:
    #   - give some examples where the filename isn't ... (the rest of this
    #     note was lost in extraction; presumably "isn't `model.pt`")
    #   - also, give examples with zero padding
    #   - also, try to keep the number of test cases to a minimum
    #   - I don't think we need 9 test cases; think about what important
    #     features each test case has and consolidate to a minimum set of
    #     test cases that covers all important features
    @pytest.mark.parametrize(
        ("path", "expected"),
        [
            ("12345/model.pt", 12345),
            ("checkpoints/12345/model.pt", 12345),
            ("path/to/500/model.pt", 500),
            ("0/model.pt", 0),
            ("prefix/run_name/12345/model.pt", 12345),
        ],
    )
    def test_directory_model_format(self, path: str, expected: int):
        """Test parsing {step}/model.pt format."""
        assert parse_checkpoint_step(path) == expected

    # Review notes (Collaborator) on the parameter list below:
    #   - give some examples where there is a number in the path, but the
    #     filename is not valid
    #   - also, try to keep the number of test cases to a minimum
    #   - I don't think we need 7 test cases; think about what important
    #     features each test case has and consolidate to a minimum set of
    #     test cases that covers all important features
    @pytest.mark.parametrize(
        "path",
        [
            "model.pt",
            "checkpoint.pt",
            "weights/model.eqx",
            "random_file.txt",
            "nonumeric/model.pt",
        ],
    )
    def test_no_match_returns_none(self, path: str):
        """Test paths that should not match any pattern."""
        assert parse_checkpoint_step(path) is None

    def test_zero_padded_step_numbers(self):
        """Test that zero-padded step numbers are correctly parsed."""
        assert parse_checkpoint_step("0000/model.pt") == 0
|
|
||
|
|
||
class TestGetCheckpointPath:
    """Test get_checkpoint_path function."""

    def test_basic_path_construction(self):
        """Test basic checkpoint path construction."""
        path = get_checkpoint_path(Path("checkpoints"), 12345)
        assert path == Path("checkpoints/12345/model.pt")

    def test_custom_filename(self):
        """Test with custom filename."""
        path = get_checkpoint_path(Path("weights"), 100, "state.pt")
        assert path == Path("weights/100/state.pt")

    # Review note (Collaborator): use an example filename with zero padding in
    # the step number, and consolidate the total number of test cases.
    @pytest.mark.parametrize(
        ("directory", "step", "filename", "expected"),
        [
            (Path("checkpoints"), 0, "model.pt", Path("checkpoints/0/model.pt")),
            (Path("runs/exp1"), 1000, "checkpoint.pt", Path("runs/exp1/1000/checkpoint.pt")),
            (Path("."), 42, "model.pt", Path("42/model.pt")),
        ],
    )
    def test_parametrized_paths(self, directory: Path, step: int, filename: str, expected: Path):
        """Test various path combinations."""
        assert get_checkpoint_path(directory, step, filename) == expected
|
|
||
|
|
||
class TestFormatStepNumber:
    """Test format_step_number function."""

    def test_basic_formatting(self):
        """Test basic zero-padding behavior."""
        assert format_step_number(42, max_steps=100) == "042"
        assert format_step_number(5, max_steps=1000) == "0005"

    def test_no_padding_needed(self):
        """Test when step already has maximum width."""
        assert format_step_number(999, max_steps=999) == "999"
        assert format_step_number(100, max_steps=100) == "100"

    def test_zero_step(self):
        """Test formatting step 0."""
        assert format_step_number(0, max_steps=100) == "000"
        assert format_step_number(0, max_steps=10000) == "00000"

    # Review note (Collaborator): too many test cases, please consolidate.
    @pytest.mark.parametrize(
        ("step", "max_steps", "expected"),
        [
            (0, 999, "000"),
            (1, 999, "001"),
            (42, 999, "042"),
            (999, 999, "999"),
            (0, 100000, "000000"),
            (42, 100000, "000042"),
            (12345, 100000, "012345"),
            (100000, 100000, "100000"),
        ],
    )
    def test_parametrized_formatting(self, step: int, max_steps: int, expected: str):
        """Test various step and max_steps combinations."""
        assert format_step_number(step, max_steps) == expected

    # Review note (Collaborator, posted twice): please remove.
    def test_lexicographic_ordering(self):
        """Verify that formatted strings sort lexicographically."""
        max_steps = 10000
        formatted = [format_step_number(i, max_steps) for i in [1, 10, 100, 1000, 9999]]
        assert formatted == sorted(formatted)

    def test_width_computation(self):
        """Verify format_step_number computes width correctly."""
        max_steps = 100000
        step = 42
        formatted = format_step_number(step, max_steps)
        expected_width = len(str(max_steps))
        assert len(formatted) == expected_width
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make use of `format_step_number` here so that sorted checkpoint paths naturally follow the ordering of the step number. Take an optional `max_steps` argument after `filename`, and use `format_step_number` if the `max_steps` argument is supplied.