Add configuration resolution and utility functions #86
Open

adamimos wants to merge 18 commits into main from sculptor/add-seq-len-calculator-util
Changes from 9 of 18 commits.

Commits:
- c744d70 Create MLFlow persister (ealt)
- b93d136 Fix lint issues (ealt)
- ad18345 Use Unity Catelog (ealt)
- 137e2f7 Add configuration resolution and utility functions (adamimos)
- d9f9283 Add comprehensive test coverage for config resolution utilities (adamimos)
- bf2a3a6 Apply ruff formatting to test_config_resolution.py (adamimos)
- 5aa7887 Address PR review feedback from ealt (adamimos)
- dd12490 Add default use_eos=False to compute_model_vocab_size for API consist… (adamimos)
- 4ede1d1 Add comprehensive input validation to prevent production issues (adamimos)
- 886b754 Address all PR review feedback from ealt (adamimos)
- b63dbc5 Fix parse_checkpoint_step to handle .eqx extension (adamimos)
- 2844fa6 Address PR review feedback (github-actions[bot])
- bb00983 Apply ruff formatting (adamimos)
- 4416f0d Switch from Unity Catalog to Workspace Model Registry (ealt)
- 5b210c7 Add workspace fallback, document potential migration (ealt)
- 85691a1 Create demo (ealt)
- 12f52f7 Merge mlflow-persister branch to add MLFlowPersister and model regist… (adamimos)
- ab0da89 Fix frozen instance error in MLFlowPersister (adamimos)
New file (+113 lines):

```python
from pathlib import Path

SUPPORTED_EXTENSIONS = (".pt", ".eqx", ".pkl", ".ckpt", ".pth")


def _is_valid_checkpoint_filename(filename: str) -> bool:
    """Check if filename is a valid checkpoint filename with supported extension.

    Args:
        filename: The checkpoint filename to validate

    Returns:
        True if filename has a supported extension, False otherwise

    Examples:
        >>> _is_valid_checkpoint_filename("model.pt")
        True
        >>> _is_valid_checkpoint_filename("state.eqx")
        True
        >>> _is_valid_checkpoint_filename("invalid.txt")
        False
    """
    return filename.endswith(SUPPORTED_EXTENSIONS)


def get_checkpoint_path(
    directory: Path, step: int, filename: str = "model.pt", max_steps: int | None = None
) -> Path:
    """Construct checkpoint path following the standard naming convention.

    Args:
        directory: Base directory for checkpoints
        step: Training step number (must be non-negative)
        filename: Checkpoint filename (default: "model.pt")
        max_steps: Maximum number of training steps. If provided, step will be zero-padded

    Returns:
        Path to checkpoint file: {directory}/{step}/{filename}

    Raises:
        ValueError: If step is negative or filename has unsupported extension

    Examples:
        >>> get_checkpoint_path(Path("checkpoints"), 12345)
        PosixPath('checkpoints/12345/model.pt')
        >>> get_checkpoint_path(Path("weights"), 100, "state.pt")
        PosixPath('weights/100/state.pt')
        >>> get_checkpoint_path(Path("checkpoints"), 42, max_steps=100000)
        PosixPath('checkpoints/000042/model.pt')
    """
    if step < 0:
        raise ValueError(f"Step must be non-negative, got {step}")
    if not _is_valid_checkpoint_filename(filename):
        raise ValueError(f"Filename must have one of these extensions: {SUPPORTED_EXTENSIONS}, got {filename}")

    if max_steps is not None:
        step_str = format_step_number(step, max_steps)
    else:
        step_str = str(step)

    return directory / step_str / filename


def parse_checkpoint_step(path: str) -> int | None:
    """Extract training step number from checkpoint path.

    Handles the format: {step}/model.pt or {step}/{filename}

    Args:
        path: File path or S3 key containing checkpoint

    Returns:
        Step number if found, None otherwise

    Examples:
        >>> parse_checkpoint_step("checkpoints/12345/model.pt")
        12345
        >>> parse_checkpoint_step("12345/model.pt")
        12345
    """
    parts = path.split("/")
    if len(parts) >= 2 and _is_valid_checkpoint_filename(parts[-1]):
        try:
            return int(parts[-2])
        except ValueError:
            pass

    return None


def format_step_number(step: int, max_steps: int) -> str:
    """Format step number with appropriate zero-padding.

    Args:
        step: Current training step
        max_steps: Maximum number of training steps

    Returns:
        Zero-padded step string

    Raises:
        ValueError: If step is not between 0 and max_steps

    Examples:
        >>> format_step_number(42, max_steps=100000)
        '000042'
        >>> format_step_number(999, max_steps=999)
        '999'
    """
    if not 0 <= step <= max_steps:
        raise ValueError(f"Step {step} must be between 0 and {max_steps}")
    width = len(str(max_steps))
    return f"{step:0{width}d}"
```
New file (+99 lines):

```python
def compute_generator_sequence_length(model_n_ctx: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    """Compute the generator's sequence length from model context length and special token usage.

    The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

    Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS - EOS

    Args:
        model_n_ctx: The model's context length (number of input positions it processes)
        use_bos: Whether a beginning-of-sequence token is prepended during data generation
        use_eos: Whether an end-of-sequence token is appended during data generation

    Returns:
        The sequence length to configure for the data generator

    Raises:
        ValueError: If the resulting generator sequence length would be non-positive

    Examples:
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False)
        512
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=False)
        513
        >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True)
        511
    """
    assert model_n_ctx > 0, f"model_n_ctx must be positive, got {model_n_ctx}"

    result = model_n_ctx + 1 - int(use_bos) - int(use_eos)
    if result <= 0:
        raise ValueError(
            f"Invalid configuration: model_n_ctx={model_n_ctx}, use_bos={use_bos}, use_eos={use_eos} "
            f"results in non-positive generator sequence length ({result})"
        )
    return result


def compute_model_context_length(generator_seq_len: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    """Compute the model's context length from generator sequence length and special token usage.

    The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS

    Args:
        generator_seq_len: The sequence length configured for the data generator
        use_bos: Whether a beginning-of-sequence token is prepended during data generation
        use_eos: Whether an end-of-sequence token is appended during data generation

    Returns:
        The context length for the model (number of input positions it will process)

    Raises:
        ValueError: If the resulting model context length would be non-positive

    Examples:
        >>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False)
        512
        >>> compute_model_context_length(generator_seq_len=513, use_bos=False, use_eos=False)
        512
        >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True)
        512
    """
    assert generator_seq_len > 0, f"generator_seq_len must be positive, got {generator_seq_len}"

    result = generator_seq_len - 1 + int(use_bos) + int(use_eos)
    if result <= 0:
        raise ValueError(
            f"Invalid configuration: generator_seq_len={generator_seq_len}, use_bos={use_bos}, use_eos={use_eos} "
            f"results in non-positive model context length ({result})"
        )
    return result


def compute_model_vocab_size(generator_vocab_size: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    """Compute the model's vocabulary size from generator vocab and special tokens.

    When BOS or EOS tokens are used during data generation, they are added to the vocabulary,
    increasing the total vocab size the model needs to handle.

    Args:
        generator_vocab_size: The vocabulary size of the data generator
        use_bos: Whether a beginning-of-sequence token is used during data generation
        use_eos: Whether an end-of-sequence token is used during data generation

    Returns:
        The vocabulary size the model should be configured with

    Raises:
        ValueError: If generator_vocab_size is non-positive

    Examples:
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False)
        101
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True)
        102
        >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False)
        100
    """
    assert generator_vocab_size > 0, f"generator_vocab_size must be positive, got {generator_vocab_size}"
    return generator_vocab_size + int(use_bos) + int(use_eos)
```
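The two length helpers are inverses of each other. A minimal standalone check of that identity, using the formulas taken from the docstrings above (validation stripped for brevity):

```python
def generator_seq_len(model_n_ctx: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    # generator_seq_len = model_n_ctx + 1 - BOS - EOS
    return model_n_ctx + 1 - int(use_bos) - int(use_eos)

def model_ctx_len(generator_seq_len: int, *, use_bos: bool = False, use_eos: bool = False) -> int:
    # model_n_ctx = generator_seq_len - 1 + BOS + EOS
    return generator_seq_len - 1 + int(use_bos) + int(use_eos)

# the formulas round-trip for every BOS/EOS combination
for bos in (False, True):
    for eos in (False, True):
        n = generator_seq_len(512, use_bos=bos, use_eos=eos)
        assert model_ctx_len(n, use_bos=bos, use_eos=eos) == 512
```

The intuition behind the `- 1`: a generator sequence of length L yields L - 1 next-token prediction positions, and each special token added during generation consumes one more model input position.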
Review comment: Among the set of all strings, there are more constraints on a valid filename than simply ending with an acceptable extension.
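One way to address this feedback: a stricter check that also rejects empty stems, path separators, and whitespace. This is a hypothetical sketch of what a tightened validator could look like, not code from the PR; the name and regex are illustrative assumptions.

```python
import re

SUPPORTED_EXTENSIONS = (".pt", ".eqx", ".pkl", ".ckpt", ".pth")

# hypothetical stricter pattern: a stem starting with a word character,
# followed by word characters, dots, or hyphens only (no slashes, no spaces)
_FILENAME_RE = re.compile(r"^[\w][\w.-]*$")

def is_valid_checkpoint_filename(filename: str) -> bool:
    # shape of the whole name, not just its suffix
    stem_ok = bool(_FILENAME_RE.match(filename))
    ext_ok = filename.endswith(SUPPORTED_EXTENSIONS)
    # reject a bare extension such as ".pt", which endswith() alone accepts
    not_bare_ext = filename not in SUPPORTED_EXTENSIONS
    return stem_ok and ext_ok and not_bare_ext
```

A bare `endswith()` check accepts names like `".pt"` or `"bad name.pt"`; the sketch above rejects both while still accepting `"model.pt"`.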