diff --git a/.github/actions/install-ci-dependencies/action.yml b/.github/actions/install-ci-dependencies/action.yml index e84c949..b03d8bc 100644 --- a/.github/actions/install-ci-dependencies/action.yml +++ b/.github/actions/install-ci-dependencies/action.yml @@ -37,28 +37,37 @@ runs: # - requirements.txt excludes torch (dependencies resolved against nightly) # - Skip on macOS ARM64 - nightly +cpu wheels aren't available, use stable torch # - Skip for "oldest" builds - they use stable torch from lock file - - name: Install torch nightly + - name: Install torch prerelease (if configured) if: inputs.use_oldest != 'true' shell: bash run: | - NIGHTLY_FILE="requirements/ci/torch-nightly.txt" - if [[ -f "${NIGHTLY_FILE}" ]]; then - TORCH_VERSION=$(grep -v '^#' "${NIGHTLY_FILE}" | grep -v '^$' | head -1 || true) - if [[ -n "${TORCH_VERSION}" ]]; then + PRE_FILE="requirements/ci/torch-pre.txt" + if [[ -f "${PRE_FILE}" ]]; then + # Read 3-line config: version, CUDA target, channel type + # Use while loop for bash 3.2 compatibility (macOS) + PRE_CONFIG=() + while IFS= read -r line; do + PRE_CONFIG+=("$line") + done < <(grep -v '^#' "${PRE_FILE}" | grep -v '^$') + TORCH_VERSION="${PRE_CONFIG[0]}" + CUDA_TARGET="${PRE_CONFIG[1]}" + CHANNEL_TYPE="${PRE_CONFIG[2]}" + + if [[ -n "${TORCH_VERSION}" && -n "${CHANNEL_TYPE}" ]]; then if [[ "${{ runner.os }}" == "macOS" ]]; then - echo "Skipping torch nightly on macOS ARM64 (no +cpu wheels available)" + echo "Skipping torch ${CHANNEL_TYPE} on macOS ARM64 (no +cpu wheels available)" echo "Will install stable torch with --torch-backend=auto" else - echo "Installing torch nightly: ${TORCH_VERSION}+cpu" + echo "Installing torch ${CHANNEL_TYPE}: ${TORCH_VERSION}+cpu" uv pip install --prerelease=allow "torch==${TORCH_VERSION}+cpu" \ - --index-url https://download.pytorch.org/whl/nightly/cpu + --index-url "https://download.pytorch.org/whl/${CHANNEL_TYPE}/cpu" fi fi fi # Install FTS and all dependencies # - UV_OVERRIDE: applies Lightning commit pin from overrides.txt - # - When nightly configured (and not macOS): torch already installed, just install rest + # - When prerelease configured (and not macOS): torch already installed, just install rest # - For macOS: use --torch-backend=auto for MPS-compatible stable torch # - For stable builds: use --torch-backend=cpu - name: Install FTS and all dependencies @@ -74,27 +83,32 @@ runs: echo "Installing with latest versions..." 
fi - # Check if torch nightly is configured - NIGHTLY_FILE="requirements/ci/torch-nightly.txt" - TORCH_NIGHTLY="false" - if [[ -f "${NIGHTLY_FILE}" ]]; then - TORCH_VERSION=$(grep -v '^#' "${NIGHTLY_FILE}" | grep -v '^$' | head -1 || true) + # Check if torch prerelease is configured + PRE_FILE="requirements/ci/torch-pre.txt" + TORCH_PRERELEASE="false" + if [[ -f "${PRE_FILE}" ]]; then + # Use while loop for bash 3.2 compatibility (macOS) + PRE_CONFIG=() + while IFS= read -r line; do + PRE_CONFIG+=("$line") + done < <(grep -v '^#' "${PRE_FILE}" | grep -v '^$') + TORCH_VERSION="${PRE_CONFIG[0]}" if [[ -n "${TORCH_VERSION}" ]]; then - TORCH_NIGHTLY="true" + TORCH_PRERELEASE="true" fi fi # Determine install command based on platform and torch configuration # - Oldest builds: use --torch-backend=cpu for stable torch - # - macOS: use --torch-backend=auto (MPS), nightly not available - # - Linux/Windows with nightly: torch already installed, no backend flag needed + # - macOS: use --torch-backend=auto (MPS), prerelease not available + # - Linux/Windows with prerelease: torch already installed, no backend flag needed # - Linux/Windows with stable: use --torch-backend=cpu if [[ "${{ inputs.use_oldest }}" == "true" ]]; then INSTALL_CMD="uv pip install -e . -r ${REQ_FILE} --torch-backend=cpu" elif [[ "${{ runner.os }}" == "macOS" ]]; then INSTALL_CMD="uv pip install -e . -r ${REQ_FILE} --torch-backend=auto" - elif [[ "${TORCH_NIGHTLY}" == "true" ]]; then - # Torch nightly already installed, just install FTS and deps + elif [[ "${TORCH_PRERELEASE}" == "true" ]]; then + # Torch prerelease already installed, just install FTS and deps INSTALL_CMD="uv pip install -e . -r ${REQ_FILE}" else INSTALL_CMD="uv pip install -e . -r ${REQ_FILE} --torch-backend=cpu" diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index a60a8ba..2129feb 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -199,6 +199,43 @@ source ${FTS_VENV_BASE}/${FTS_TARGET_VENV}/bin/activate PL_RUN_STANDALONE_TESTS=1 python -m pytest tests/test_specific.py::test_function -v ``` +### Building Documentation + +**Documentation build commands (needs activated venv):** + +```bash +export FTS_VENV_BASE=/mnt/cache/${USER}/.venvs +export FTS_TARGET_VENV=fts_latest +export FTS_REPO_DIR=${HOME}/repos/finetuning-scheduler # Example: adjust to your local repo path +# Activate your environment first +cd ${FTS_REPO_DIR} && \ +source ${FTS_VENV_BASE}/${FTS_TARGET_VENV}/bin/activate + +# Clean previous builds +cd docs && make clean + +# Build HTML documentation with warnings as errors +make html --debug SPHINXOPTS="-W --keep-going" + +# Run linkcheck to verify all links +make linkcheck SPHINXOPTS="-W --keep-going" + +# Check for errors in linkcheck output +grep -i "error\|broken" build/linkcheck/output.txt || echo "No errors found in linkcheck" +``` + +**Documentation requirements:** + +- All documentation must build without warnings when using `-W` flag +- All internal and external links must be valid (verified by linkcheck) +- RST cross-references should use appropriate directives: + - `:class:` for class references + - `:meth:` for method references + - `:func:` for function references + - `:doc:` for document references + - `:ref:` for section references (requires explicit label like `.. 
_label_name:`) +- Sphinx autosummary generates API documentation from docstrings + ## Project Layout and Architecture ### Source Code Structure diff --git a/.gitignore b/.gitignore index 658f257..02bc844 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,8 @@ timit_data/ grid_generated* grid_ori* - +# we don't have repo-specific prompts at this juncture +.github/prompts/ # C extensions *.so diff --git a/README.md b/README.md index 472d695..696acbf 100644 --- a/README.md +++ b/README.md @@ -226,7 +226,7 @@ See the [versioning documentation](https://finetuning-scheduler.readthedocs.io/e
Current build statuses for Fine-Tuning Scheduler -| System / (PyTorch/Python ver) | 2.6.0/3.9 | 2.10.0/3.9, 2.10.0/3.12 | +| System / (PyTorch/Python ver) | 2.6.0/3.10 | 2.10.0/3.10, 2.10.0/3.12 | | :---------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | Linux \[GPUs\*\*\] | - | [![Build Status](https://dev.azure.com//speediedan/finetuning-scheduler/_apis/build/status/Multi-GPU%20&%20Example%20Tests?branchName=main)](https://dev.azure.com/speediedan/finetuning-scheduler/_build/latest?definitionId=1&branchName=main) | | Linux (Ubuntu 22.04) | [![Test](https://github.com/speediedan/finetuning-scheduler/actions/workflows/ci_test-full.yml/badge.svg?branch=main&event=push)](https://github.com/speediedan/finetuning-scheduler/actions/workflows/ci_test-full.yml) | [![Test](https://github.com/speediedan/finetuning-scheduler/actions/workflows/ci_test-full.yml/badge.svg?branch=main&event=push)](https://github.com/speediedan/finetuning-scheduler/actions/workflows/ci_test-full.yml) | diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index caa8e41..3e63426 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -81,13 +81,11 @@ RUN \ else \ # or target a specific cuda build, by specifying a particular index url w/... # ... default channel - # uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128; \ - # ... pytorch patch version - # uv pip install torch==1.11.1+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html; \ + # uv pip install torch --torch-backend=cu128; \ # ... pytorch nightly dev version - uv pip install --prerelease=allow torch==2.10.0.dev20251124 --index-url https://download.pytorch.org/whl/nightly/cu128; \ + # uv pip install --prerelease=allow torch==2.10.0.dev20251124 --index-url https://download.pytorch.org/whl/nightly/cu128; \ # ... test channel - # uv pip install --prerelease=allow torch==2.10.0 --index-url https://download.pytorch.org/whl/test/cu128; \ + uv pip install --prerelease=allow torch==2.10.0 --index-url https://download.pytorch.org/whl/test/cu128; \ fi && \ # We avoid installing Lightning and other dependencies here as they are usually upgraded anyway later in # CI but we may re-enable in the future. diff --git a/docs/source/distributed/fsdp_scheduled_fine_tuning.rst b/docs/source/distributed/fsdp_scheduled_fine_tuning.rst index 4c033f2..a9d886b 100644 --- a/docs/source/distributed/fsdp_scheduled_fine_tuning.rst +++ b/docs/source/distributed/fsdp_scheduled_fine_tuning.rst @@ -309,10 +309,6 @@ While not technically required, we add ``DebertaV2Embeddings`` separately as wel As always, if needed, one can alternatively override ``configure_model`` and manually wrap a given :external+pl:class:`~lightning.pytorch.core.module.LightningModule` to align with a desired fine-tuning schedule. -.. warning:: - - :class:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter` is in BETA and subject to change. The - interface can bring breaking changes and new features with the next release of PyTorch. .. 
note:: diff --git a/docs/source/index.rst b/docs/source/index.rst index ca6ff9a..57a3d4f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -523,6 +523,13 @@ Footnotes advanced/lr_scheduler_reinitialization advanced/optimizer_reinitialization +.. toctree:: + :maxdepth: 1 + :name: Plugins + :caption: Plugins + + plugins/strategy_adapter_entry_points + .. toctree:: :maxdepth: 1 :name: Basic Examples diff --git a/docs/source/plugins/strategy_adapter_entry_points.rst b/docs/source/plugins/strategy_adapter_entry_points.rst new file mode 100644 index 0000000..1554007 --- /dev/null +++ b/docs/source/plugins/strategy_adapter_entry_points.rst @@ -0,0 +1,252 @@ +.. _strategy_adapter_entry_points: + +Strategy Adapter Entry Points +############################## + +.. warning:: + This is an :ref:`experimental ` feature which is + still in development. The entry point API and plugin discovery mechanism may change in future + releases. + +Overview +******** + +Fine-Tuning Scheduler (FTS) supports custom strategy adapters via Python entry points, enabling third-party +packages to extend FTS with specialized adapters for custom training strategies, model architectures, +or parameter naming conventions (e.g. dynamic parameter views for latent space analysis). + +This plugin mechanism allows packages like `Interpretune `_ +to provide adapters that integrate seamlessly with FTS without requiring modifications to the FTS +codebase itself. + +**Important Concepts:** + +- **Lightning Strategy Flags**: Built-in PyTorch Lightning strategy identifiers (e.g., ``single_device``, ``ddp``, ``fsdp``) +- **Strategy Adapters**: Classes that extend FTS functionality for specific strategies or use cases +- **Custom Strategy Adapter Mapping**: User-provided dictionary mapping Lightning strategy flags to adapter implementations + +Note that custom strategy adapters are meant to **adapt existing Lightning strategies**, not create +wholly new ones. If you need a new strategy, register it with PyTorch Lightning first, then create +an adapter to extend FTS support for it. + +Entry Point Specification +************************** + +Entry Point Group +================= + +Custom strategy adapters are registered under the ``finetuning_scheduler.strategy_adapters`` entry +point group. + +Registration Format +=================== + +In your package's ``pyproject.toml``, register your adapter: + +.. code-block:: toml + + [project.entry-points."finetuning_scheduler.strategy_adapters"] + adapter_name = "package.module:AdapterClass" + +The entry point name (``adapter_name``) will be lowercased and used to reference the adapter. +The value should follow Python's standard entry point format: ``module:attribute`` or +``module.submodule:attribute``. + +Discovery and Loading +********************* + +Entry points are discovered lazily during strategy setup (at the start of training). The discovery process: + +1. Scans for all registered entry points in the ``finetuning_scheduler.strategy_adapters`` group +2. Attempts to load each adapter class +3. Adds successfully loaded adapters to the runtime ``STRATEGY_ADAPTERS`` mapping, keyed by the + lowercased entry point name +4. 
Logs warnings for any adapters that fail to load (without preventing FTS initialization) + +Usage Example +************* + +Real-World Example: TransformerBridge Adapter +============================================== + +The Interpretune package provides a ``TransformerBridgeStrategyAdapter`` that enables clean +TransformerLens-style parameter naming in FTS schedules. Here's how it's registered and used: + +**Registration** (in Interpretune's ``pyproject.toml``): + +.. code-block:: toml + + [project.entry-points."finetuning_scheduler.strategy_adapters"] + transformerbridge = "interpretune.adapters.transformer_lens:TransformerBridgeStrategyAdapter" + +**Usage** (in training configuration): + +.. code-block:: python + + from finetuning_scheduler import FinetuningScheduler + + # Map Lightning strategy flags to the adapter + # Multiple strategy flags can use the same adapter + fts = FinetuningScheduler( + custom_strategy_adapters={ + "single_device": "transformerbridge", # Use entry point name + "ddp": "transformerbridge", # Same adapter for DDP + # Or use fully qualified paths: + # "single_device": "interpretune.adapters.transformer_lens:TransformerBridgeStrategyAdapter", + # "ddp": "interpretune.adapters.transformer_lens.TransformerBridgeStrategyAdapter", + }, + strategy_adapter_cfg={"use_tl_names": True}, # Adapter-specific config + ) + +This allows fine-tuning schedules to use architecture-agnostic parameter names like ``blocks.9.attn.W_Q`` instead +of verbose and architecture-dependent canonical names, while FTS handles the necessary translations automatically. + +**Key Points:** + +- Strategy flags (``single_device``, ``ddp``, etc.) refer to Lightning's built-in strategies +- The same adapter can be mapped to multiple strategy flags +- Three formats are supported for referencing adapters: + + 1. Entry point name: ``"transformerbridge"`` + 2. Colon-separated path: ``"interpretune.adapters.transformer_lens:TransformerBridgeStrategyAdapter"`` + 3. Dot-separated path: ``"interpretune.adapters.transformer_lens.TransformerBridgeStrategyAdapter"`` + +Creating Custom Adapters +************************* + +Base Requirements +================= + +Custom strategy adapters should: + +1. Inherit from :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` +2. Implement required methods for your specific use case (see :doc:`/api/finetuning_scheduler.strategy_adapters`) +3. Follow the adapter lifecycle hooks (:meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.connect`, + :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.on_before_init_fts`, etc.) + +Override Points +=============== + +Strategy adapters can customize FTS behavior at multiple levels of abstraction, accommodating a variety of use cases: + +**Parameter Naming** + Override :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.get_named_params_for_schedule_validation` + to provide custom parameter names while using default validation logic. + +**Full Validation** + Override :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.validate_ft_sched` + to completely customize schedule validation. + +**Schedule Generation** + Override :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.gen_ft_schedule` + to customize how default schedules are generated.
+ +**Checkpoint Handling** + Override :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.before_restore_model`, + ``lightning_module_state_dict()``, and ``load_model_state_dict()`` for custom checkpoint translation logic. + +See :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` for the complete API. + +Best Practices +************** + +Robust Loading +============== + +Entry point loading is wrapped in exception handling to prevent adapter failures from breaking FTS +initialization. However, adapters should: + +- Validate dependencies and raise clear errors during ``__init__()`` if requirements aren't met +- Use meaningful exception messages to help users diagnose configuration issues +- Document any required dependencies in your package documentation + +Naming Conventions +================== + +- Use descriptive, lowercase entry point names (e.g., ``transformerbridge``, ``custom_fsdp``) +- Avoid generic names that might conflict with other packages +- Consider prefixing with your package name for uniqueness (e.g., ``mypackage_adapter``) + +Configuration +============= + +Custom Adapter Mapping Format +------------------------------ + +The :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.custom_strategy_adapters` parameter +accepts a dictionary mapping PyTorch Lightning strategy flags (canonical strategy names like +``"single_device"``, ``"auto"``, ``"ddp"``, etc.) to adapter references. The adapter reference +can be: + +1. **An entry point name** (lowercased) registered under ``finetuning_scheduler.strategy_adapters`` +2. **A fully qualified class path** in the format ``"module.path:ClassName"`` or + ``"module.path.ClassName"`` + +This allows multiple strategy flags to be associated with the same adapter. For example: + +.. code-block:: python + + from finetuning_scheduler import FinetuningScheduler + + # Multiple strategies can use the same registered plugin adapter + fts = FinetuningScheduler( + custom_strategy_adapters={ + "single_device": "transformerbridge", # Plugin entry point name + # We can use the same plugin for multiple strategies; here we use a fully qualified path format as well + "auto": "interpretune.adapters.transformer_lens.TransformerBridgeStrategyAdapter", + }, + strategy_adapter_cfg={ + "use_tl_names": True, # Configuration passed to the adapter + }, + ) + +**Native FTS Adapters**: FTS includes built-in adapters in the ``STRATEGY_ADAPTERS`` mapping +that are always available: + +- ``"fsdp"`` - :class:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter` +- ``"modelparallelstrategy"`` - :class:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter` + +These can be referenced directly without requiring plugin registration. + +Adapter-Specific Configuration +------------------------------- + +If your adapter accepts configuration, use the ``strategy_adapter_cfg`` parameter: + +.. code-block:: python + + fts = FinetuningScheduler( + custom_strategy_adapters={"target_strategy": "my_adapter"}, + strategy_adapter_cfg={ + "option1": value1, + "option2": value2, + }, + ) + +Testing +======= + +Test your adapter with FTS by: + +1. Creating test fixtures that instantiate FTS with your adapter +2. Verifying schedule validation works with your parameter naming +3. Testing checkpoint save/restore if you override those methods +4. 
Ensuring your adapter works with both explicit and implicit schedules + +Future Directions +***************** + +This plugin system may be extended in future releases to support: + +- Versioned adapter APIs +- Additional extension points beyond strategy adapters + +Community adapters and feedback on the plugin system are welcome! Please share your use cases +and suggestions on the `GitHub repository `_. + +See Also +******** + +- :doc:`/api/finetuning_scheduler.strategy_adapters` +- :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` +- :ref:`versioning:API Stability Classifications` diff --git a/pyproject.toml b/pyproject.toml index d8282e7..15f5b93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,12 @@ authors = [{name = "Daniel Dale", email = "danny.dale@gmail.com"}] license = "Apache-2.0" license-files = ["LICENSE*"] requires-python = ">=3.10" + +# NOTE: Minimum versions for torch and Lightning are managed dynamically by setup.py +# See src/finetuning_scheduler/dynamic_versioning/utils.py for version specifications: +# - torch: defined in BASE_DEPENDENCIES +# - Lightning: defined in LIGHTNING_PACKAGES (unified/standalone) +# The 'dependencies' field above is marked as 'dynamic' and populated at build time. keywords = [ "deep learning", "pytorch", @@ -118,6 +124,17 @@ all = [ "traitlets>=5.0.0", ] +# ----------------------------------------------------------------------------- +# FTS Metadata - Minimum Version Requirements (informational only) +# ----------------------------------------------------------------------------- +# These values are for documentation purposes and are NOT used during installation. +# Actual minimum versions are enforced in src/finetuning_scheduler/dynamic_versioning/utils.py +# The build system (setup.py) reads from utils.py to populate the dynamic dependencies. 
+[tool.fts.min-versions] +torch = ">=2.6.0" +lightning = ">=2.6.0,<2.6.1" +python = "3.10" # Defined in pyproject.toml requires-python + [project.urls] "Homepage" = "https://github.com/speediedan/finetuning-scheduler" "Bug Tracker" = "https://github.com/speediedan/finetuning-scheduler/issues" diff --git a/requirements/ci/requirements-oldest.txt b/requirements/ci/requirements-oldest.txt index 20e13f0..1e8d69e 100644 --- a/requirements/ci/requirements-oldest.txt +++ b/requirements/ci/requirements-oldest.txt @@ -1,12 +1,12 @@ # This file was autogenerated by uv via the following command: -# uv pip compile /home/speediedan/repos/finetuning-scheduler/pyproject.toml --extra all --group dev --group test --output-file /home/speediedan/repos/finetuning-scheduler/requirements/ci/requirements-oldest.txt --no-strip-extras --resolution lowest-direct --universal --python-version 3.10 +# uv pip compile pyproject.toml --extra all --group dev --group test --output-file /home/speediedan/repos/finetuning-scheduler/requirements/ci/requirements-oldest.txt --no-strip-extras --resolution lowest-direct --universal --python-version 3.10 aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.13.2 +aiohttp==3.13.3 # via fsspec aiosignal==1.4.0 # via aiohttp -alembic==1.17.2 +alembic==1.18.0 # via mlflow annotated-doc==0.0.4 # via fastapi @@ -16,7 +16,7 @@ antlr4-python3-runtime==4.9.3 # via # hydra-core # omegaconf -anyio==4.12.0 +anyio==4.12.1 # via starlette appnope==0.1.4 ; sys_platform == 'darwin' # via @@ -49,10 +49,8 @@ bleach==6.3.0 blinker==1.9.0 # via flask cachetools==5.5.2 - # via - # google-auth - # mlflow-skinny -certifi==2025.11.12 + # via mlflow-skinny +certifi==2026.1.4 # via requests cffi==2.0.0 # via @@ -92,13 +90,13 @@ cryptography==46.0.3 ; sys_platform == 'linux' # via secretstorage cycler==0.12.1 # via matplotlib -databricks-sdk==0.73.0 +databricks-sdk==0.77.0 # via mlflow-skinny datasets==4.0.0 # via # finetuning-scheduler (pyproject.toml) # evaluate -debugpy==1.8.17 +debugpy==1.8.19 # via ipykernel decorator==5.2.1 # via ipython @@ -115,7 +113,7 @@ docker==7.1.0 # via mlflow docstring-parser==0.17.0 # via jsonargparse -docutils==0.22.3 +docutils==0.22.4 # via readme-renderer entrypoints==0.4 # via @@ -125,9 +123,9 @@ evaluate==0.3.0 # via finetuning-scheduler (pyproject.toml) exceptiongroup==1.3.1 ; python_full_version < '3.11' # via anyio -fastapi==0.124.0 +fastapi==0.128.0 # via mlflow-skinny -filelock==3.20.0 +filelock==3.20.3 # via # datasets # huggingface-hub @@ -136,7 +134,7 @@ filelock==3.20.0 # virtualenv flask==3.1.2 # via mlflow -fonttools==4.61.0 +fonttools==4.61.1 # via matplotlib frozenlist==1.8.0 # via @@ -152,9 +150,9 @@ fsspec[http]==2025.3.0 # torch gitdb==4.0.12 # via gitpython -gitpython==3.1.45 +gitpython==3.1.46 # via mlflow-skinny -google-auth==2.43.0 +google-auth==2.47.0 # via databricks-sdk graphene==3.4.3 # via mlflow @@ -186,7 +184,7 @@ idna==3.11 # anyio # requests # yarl -importlib-metadata==8.7.0 +importlib-metadata==8.7.1 # via # keyring # mlflow-skinny @@ -217,9 +215,9 @@ itsdangerous==2.2.0 # via flask jaraco-classes==3.4.0 # via keyring -jaraco-context==6.0.1 +jaraco-context==6.0.2 # via keyring -jaraco-functools==4.3.0 +jaraco-functools==4.4.0 # via keyring jedi==0.19.2 # via ipython @@ -233,7 +231,7 @@ jinja2==3.1.6 # nbconvert # notebook # torch -joblib==1.5.2 +joblib==1.5.3 # via scikit-learn jsonargparse[signatures]==4.27.7 # via finetuning-scheduler (pyproject.toml) @@ -282,7 +280,7 @@ markupsafe==3.0.3 # jinja2 # mako # werkzeug -matplotlib==3.10.7 
+matplotlib==3.10.8 # via mlflow matplotlib-inline==0.1.7 # via @@ -337,11 +335,11 @@ nest-asyncio==1.6.0 # nbclient networkx==3.4.2 ; python_full_version < '3.11' # via torch -networkx==3.6 ; python_full_version >= '3.11' +networkx==3.6.1 ; python_full_version >= '3.11' # via torch nh3==0.3.2 # via readme-renderer -nodeenv==1.9.1 +nodeenv==1.10.0 # via # pre-commit # pyright @@ -362,7 +360,7 @@ numpy==2.2.6 ; python_full_version < '3.11' # tensorboardx # torchmetrics # transformers -numpy==2.3.5 ; python_full_version >= '3.11' +numpy==2.4.0 ; python_full_version >= '3.11' # via # contourpy # datasets @@ -414,14 +412,14 @@ omegaconf==2.2.3 # via # finetuning-scheduler (pyproject.toml) # hydra-core -opentelemetry-api==1.39.0 +opentelemetry-api==1.39.1 # via # mlflow-skinny # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-sdk==1.39.0 +opentelemetry-sdk==1.39.1 # via mlflow-skinny -opentelemetry-semantic-conventions==0.60b0 +opentelemetry-semantic-conventions==0.60b1 # via opentelemetry-sdk packaging==25.0 # via @@ -452,7 +450,7 @@ pexpect==4.9.0 ; sys_platform != 'win32' # via ipython pickleshare==0.7.5 # via ipython -pillow==12.0.0 +pillow==12.1.0 # via matplotlib pip==21.0 # via finetuning-scheduler (pyproject.toml) @@ -511,7 +509,7 @@ pygments==2.19.2 # nbconvert # readme-renderer # rich -pyparsing==3.2.5 +pyparsing==3.3.1 # via matplotlib pyright==1.1.390 # via @@ -604,7 +602,7 @@ scipy==1.16.3 ; python_full_version >= '3.11' # scikit-learn secretstorage==3.5.0 ; sys_platform == 'linux' # via keyring -send2trash==1.8.3 +send2trash==2.0.0 # via notebook sentencepiece==0.2.0 # via finetuning-scheduler (pyproject.toml) @@ -621,11 +619,11 @@ six==1.17.0 # python-dateutil smmap==5.0.2 # via gitdb -sqlalchemy==2.0.44 +sqlalchemy==2.0.45 # via # alembic # mlflow -sqlparse==0.5.4 +sqlparse==0.5.5 # via mlflow-skinny starlette==0.50.0 # via fastapi @@ -661,7 +659,7 @@ torchmetrics==1.8.2 # via # lightning # pytorch-lightning -tornado==6.5.2 +tornado==6.5.4 # via # ipykernel # jupyter-client @@ -729,16 +727,16 @@ typing-extensions==4.15.0 # virtualenv typing-inspection==0.4.2 # via pydantic -tzdata==2025.2 +tzdata==2025.3 # via pandas -urllib3==2.6.0 +urllib3==2.6.3 # via # docker # requests # responses -uvicorn==0.38.0 +uvicorn==0.40.0 # via mlflow-skinny -virtualenv==20.35.4 +virtualenv==20.36.1 # via pre-commit waitress==3.0.2 ; sys_platform == 'win32' # via mlflow @@ -746,7 +744,7 @@ wcwidth==0.2.14 # via prompt-toolkit webencodings==0.5.1 # via bleach -werkzeug==3.1.4 +werkzeug==3.1.5 # via flask widgetsnbextension==3.5.2 # via ipywidgets diff --git a/requirements/ci/requirements.txt b/requirements/ci/requirements.txt index 6a0e957..24c21b4 100644 --- a/requirements/ci/requirements.txt +++ b/requirements/ci/requirements.txt @@ -1,12 +1,12 @@ # This file was autogenerated by uv via the following command: -# uv pip compile /home/speediedan/repos/finetuning-scheduler/pyproject.toml --extra all --group dev --group test --output-file /home/speediedan/repos/finetuning-scheduler/requirements/ci/requirements.txt --no-strip-extras --resolution highest --universal --python-version 3.10 --prerelease=if-necessary-or-explicit --override /tmp/tmp.CNVCWjH0nH --index-strategy unsafe-best-match --no-emit-package torch +# uv pip compile pyproject.toml --extra all --group dev --group test --output-file /home/speediedan/repos/finetuning-scheduler/requirements/ci/requirements.txt --no-strip-extras --resolution highest --universal --python-version 3.10 --prerelease=if-necessary-or-explicit 
--override /tmp/tmp.qrV2Vucuzc --index-strategy unsafe-best-match --no-emit-package torch aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.13.2 +aiohttp==3.13.3 # via fsspec aiosignal==1.4.0 # via aiohttp -alembic==1.17.2 +alembic==1.18.0 # via mlflow annotated-doc==0.0.4 # via fastapi @@ -16,7 +16,7 @@ antlr4-python3-runtime==4.9.3 # via # hydra-core # omegaconf -anyio==4.12.0 +anyio==4.12.1 # via # httpx # jupyter-server @@ -50,12 +50,11 @@ bleach[css]==6.3.0 # via nbconvert blinker==1.9.0 # via flask -cachetools==6.2.2 +cachetools==6.2.4 # via - # google-auth # mlflow-skinny # mlflow-tracing -certifi==2025.11.12 +certifi==2026.1.4 # via # httpcore # httpx @@ -90,7 +89,7 @@ contourpy==1.3.2 ; python_full_version < '3.11' # via matplotlib contourpy==1.3.3 ; python_full_version >= '3.11' # via matplotlib -coverage==7.12.0 +coverage==7.13.1 # via # finetuning-scheduler (pyproject.toml:dev) # finetuning-scheduler (pyproject.toml:test) @@ -101,15 +100,15 @@ cryptography==46.0.3 # secretstorage cycler==0.12.1 # via matplotlib -databricks-sdk==0.73.0 +databricks-sdk==0.77.0 # via # mlflow-skinny # mlflow-tracing -datasets==4.4.1 +datasets==4.4.2 # via # finetuning-scheduler (pyproject.toml) # evaluate -debugpy==1.8.17 +debugpy==1.8.19 # via ipykernel decorator==5.2.1 # via ipython @@ -126,7 +125,7 @@ docker==7.1.0 # via mlflow docstring-parser==0.17.0 # via jsonargparse -docutils==0.22.3 +docutils==0.22.4 # via readme-renderer evaluate==0.4.6 # via finetuning-scheduler (pyproject.toml) @@ -137,11 +136,11 @@ exceptiongroup==1.3.1 ; python_full_version < '3.11' # pytest executing==2.2.1 # via stack-data -fastapi==0.124.0 +fastapi==0.128.0 # via mlflow-skinny fastjsonschema==2.21.2 # via nbformat -filelock==3.20.0 +filelock==3.20.3 # via # datasets # huggingface-hub @@ -152,9 +151,9 @@ flask==3.1.2 # via # flask-cors # mlflow -flask-cors==6.0.1 +flask-cors==6.0.2 # via mlflow -fonttools==4.61.0 +fonttools==4.61.1 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -172,9 +171,9 @@ fsspec[http]==2025.10.0 # torch gitdb==4.0.12 # via gitpython -gitpython==3.1.45 +gitpython==3.1.46 # via mlflow-skinny -google-auth==2.43.0 +google-auth==2.47.0 # via databricks-sdk graphene==3.4.3 # via mlflow @@ -200,7 +199,7 @@ httpx==0.28.1 # via # datasets # jupyterlab -huey==2.5.5 +huey==2.6.0 # via mlflow huggingface-hub==0.36.0 # via @@ -221,7 +220,7 @@ idna==3.11 # jsonschema # requests # yarl -importlib-metadata==8.7.0 +importlib-metadata==8.7.1 # via # keyring # mlflow-skinny @@ -235,12 +234,12 @@ ipykernel==7.1.0 # finetuning-scheduler (pyproject.toml) # jupyterlab # nbval -ipython==8.37.0 ; python_full_version < '3.11' +ipython==8.38.0 ; python_full_version < '3.11' # via # finetuning-scheduler (pyproject.toml) # ipykernel # ipywidgets -ipython==9.8.0 ; python_full_version >= '3.11' +ipython==9.9.0 ; python_full_version >= '3.11' # via # finetuning-scheduler (pyproject.toml) # ipykernel @@ -255,9 +254,9 @@ itsdangerous==2.2.0 # via flask jaraco-classes==3.4.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via keyring -jaraco-context==6.0.1 ; platform_machine != 'ppc64le' and platform_machine != 's390x' +jaraco-context==6.0.2 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via keyring -jaraco-functools==4.3.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' +jaraco-functools==4.4.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via keyring jedi==0.19.2 # via ipython @@ -273,22 +272,22 @@ jinja2==3.1.6 # jupyterlab-server # nbconvert # torch 
-joblib==1.5.2 +joblib==1.5.3 # via scikit-learn -json5==0.12.1 +json5==0.13.0 # via jupyterlab-server jsonargparse[signatures]==4.41.0 # via finetuning-scheduler (pyproject.toml) jsonpointer==3.0.0 # via jsonschema -jsonschema[format-nongpl]==4.25.1 +jsonschema[format-nongpl]==4.26.0 # via # jupyter-events # jupyterlab-server # nbformat jsonschema-specifications==2025.9.1 # via jsonschema -jupyter-client==8.6.3 +jupyter-client==8.8.0 # via # finetuning-scheduler (pyproject.toml) # ipykernel @@ -318,7 +317,7 @@ jupyter-server==2.17.0 # notebook-shim jupyter-server-terminals==0.5.3 # via jupyter-server -jupyterlab==4.5.0 +jupyterlab==4.5.1 # via notebook jupyterlab-pygments==0.3.0 # via nbconvert @@ -357,7 +356,7 @@ markupsafe==3.0.3 # mako # nbconvert # werkzeug -matplotlib==3.10.7 +matplotlib==3.10.8 # via mlflow matplotlib-inline==0.2.1 # via @@ -367,15 +366,15 @@ mdit-py-plugins==0.5.0 # via jupytext mdurl==0.1.2 # via markdown-it-py -mistune==3.1.4 +mistune==3.2.0 # via nbconvert -mlflow==3.7.0 +mlflow==3.8.1 # via # finetuning-scheduler (pyproject.toml:dev) # finetuning-scheduler (pyproject.toml:test) -mlflow-skinny==3.7.0 +mlflow-skinny==3.8.1 # via mlflow -mlflow-tracing==3.7.0 +mlflow-tracing==3.8.1 # via mlflow more-itertools==10.8.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via @@ -389,7 +388,7 @@ multiprocess==0.70.18 # via # datasets # evaluate -nbclient==0.10.2 +nbclient==0.10.4 # via # finetuning-scheduler (pyproject.toml) # nbconvert @@ -411,11 +410,11 @@ nest-asyncio==1.6.0 # via ipykernel nh3==0.3.2 # via readme-renderer -nodeenv==1.9.1 +nodeenv==1.10.0 # via # pre-commit # pyright -notebook==7.5.0 +notebook==7.5.1 # via finetuning-scheduler (pyproject.toml) notebook-shim==0.2.4 # via @@ -434,7 +433,7 @@ numpy==2.2.6 ; python_full_version < '3.11' # tensorboardx # torchmetrics # transformers -numpy==2.3.5 ; python_full_version >= '3.11' +numpy==2.4.0 ; python_full_version >= '3.11' # via # contourpy # datasets @@ -451,21 +450,21 @@ omegaconf==2.3.0 # via # finetuning-scheduler (pyproject.toml) # hydra-core -opentelemetry-api==1.39.0 +opentelemetry-api==1.39.1 # via # mlflow-skinny # mlflow-tracing # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-proto==1.39.0 +opentelemetry-proto==1.39.1 # via # mlflow-skinny # mlflow-tracing -opentelemetry-sdk==1.39.0 +opentelemetry-sdk==1.39.1 # via # mlflow-skinny # mlflow-tracing -opentelemetry-semantic-conventions==0.60b0 +opentelemetry-semantic-conventions==0.60b1 # via opentelemetry-sdk overrides==7.7.0 ; python_full_version < '3.12' # via jupyter-server @@ -506,7 +505,7 @@ parso==0.8.5 # via jedi pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' # via ipython -pillow==12.0.0 +pillow==12.1.0 # via matplotlib pip==25.3 # via finetuning-scheduler (pyproject.toml) @@ -516,7 +515,7 @@ platformdirs==4.5.1 # virtualenv pluggy==1.6.0 # via pytest -pre-commit==4.5.0 +pre-commit==4.5.1 # via # finetuning-scheduler (pyproject.toml:dev) # finetuning-scheduler (pyproject.toml:test) @@ -535,7 +534,7 @@ protobuf==6.33.2 # mlflow-tracing # opentelemetry-proto # tensorboardx -psutil==7.1.3 +psutil==7.2.1 # via # finetuning-scheduler (pyproject.toml) # ipykernel @@ -572,13 +571,13 @@ pygments==2.19.2 # pytest # readme-renderer # rich -pyparsing==3.2.5 +pyparsing==3.3.1 # via matplotlib -pyright==1.1.407 +pyright==1.1.408 # via # finetuning-scheduler (pyproject.toml:dev) # finetuning-scheduler (pyproject.toml:test) -pytest==9.0.1 +pytest==9.0.2 # via # finetuning-scheduler 
(pyproject.toml:dev) # finetuning-scheduler (pyproject.toml:test) @@ -678,7 +677,11 @@ rsa==4.9.1 # via google-auth safetensors==0.7.0 # via transformers -scikit-learn==1.7.2 +scikit-learn==1.7.2 ; python_full_version < '3.11' + # via + # finetuning-scheduler (pyproject.toml) + # mlflow +scikit-learn==1.8.0 ; python_full_version >= '3.11' # via # finetuning-scheduler (pyproject.toml) # mlflow @@ -692,7 +695,7 @@ scipy==1.16.3 ; python_full_version >= '3.11' # scikit-learn secretstorage==3.5.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' # via keyring -send2trash==1.8.3 +send2trash==2.0.0 # via jupyter-server sentencepiece==0.2.1 # via finetuning-scheduler (pyproject.toml) @@ -707,13 +710,13 @@ six==1.17.0 # rfc3339-validator smmap==5.0.2 # via gitdb -soupsieve==2.8 +soupsieve==2.8.1 # via beautifulsoup4 -sqlalchemy==2.0.44 +sqlalchemy==2.0.45 # via # alembic # mlflow -sqlparse==0.5.4 +sqlparse==0.5.5 # via mlflow-skinny stack-data==0.6.3 # via ipython @@ -731,7 +734,7 @@ threadpoolctl==3.6.0 # via scikit-learn tinycss2==1.4.0 # via bleach -tokenizers==0.22.1 +tokenizers==0.22.2 # via transformers tomli==2.3.0 ; python_full_version < '3.11' # via @@ -743,7 +746,7 @@ torchmetrics==1.8.2 # via # lightning # pytorch-lightning -tornado==6.5.2 +tornado==6.5.4 # via # ipykernel # jupyter-client @@ -817,20 +820,20 @@ typing-extensions==4.15.0 # virtualenv typing-inspection==0.4.2 # via pydantic -tzdata==2025.2 +tzdata==2025.3 # via # arrow # pandas uri-template==1.3.0 # via jsonschema -urllib3==2.6.0 +urllib3==2.6.3 # via # docker # requests # twine -uvicorn==0.38.0 +uvicorn==0.40.0 # via mlflow-skinny -virtualenv==20.35.4 +virtualenv==20.36.1 # via pre-commit waitress==3.0.2 ; sys_platform == 'win32' # via mlflow @@ -844,7 +847,7 @@ webencodings==0.5.1 # tinycss2 websocket-client==1.9.0 # via jupyter-server -werkzeug==3.1.4 +werkzeug==3.1.5 # via # flask # flask-cors diff --git a/requirements/ci/torch-nightly.txt b/requirements/ci/torch-nightly.txt deleted file mode 100644 index 1030b4a..0000000 --- a/requirements/ci/torch-nightly.txt +++ /dev/null @@ -1,11 +0,0 @@ -# PyTorch nightly version configuration -# To enable nightly builds, uncomment and set the version and CUDA target below. -# Leave commented/empty to use stable PyTorch releases. -# -# Format: -# Line 1: torch version (e.g., 2.10.0.dev20251203) -# Line 2: CUDA target for local builds (e.g., cu128) - CI always uses cpu -# -# Example (uncomment both lines to enable): -2.10.0.dev20251124 -cu128 diff --git a/requirements/ci/torch-pre.txt b/requirements/ci/torch-pre.txt new file mode 100644 index 0000000..3418e47 --- /dev/null +++ b/requirements/ci/torch-pre.txt @@ -0,0 +1,18 @@ +# PyTorch prerelease version configuration +# To enable prerelease builds, uncomment and set the values below. +# Leave commented/empty to use stable PyTorch releases. 
+# +# Format: +# Line 1: torch version (e.g., 2.10.0 for test/RC, 2.10.0.dev20251203 for nightly) +# Line 2: CUDA target for local builds (e.g., cu128) - CI always uses cpu +# Line 3: channel type: "test" or "nightly" +# +# Example for test/RC channel (currently enabled): +2.10.0 +cu128 +test + +# Example for nightly (commented out): +# 2.10.0.dev20251124 +# cu128 +# nightly diff --git a/requirements/ci/torch_override.txt b/requirements/ci/torch_override.txt index 3a916b2..b98382d 100644 --- a/requirements/ci/torch_override.txt +++ b/requirements/ci/torch_override.txt @@ -1,16 +1,16 @@ -# PyTorch nightly override for UV_OVERRIDE -# Generated by lock_ci_requirements.sh from torch-nightly.txt +# PyTorch prerelease override for UV_OVERRIDE +# Generated by lock_ci_requirements.sh from torch-pre.txt # -# Manual installation with nightly (two-step approach for security): -# Step 1: Install PyTorch nightly -# uv pip install --prerelease=if-necessary-or-explicit torch==2.10.0.dev20251124 --index-url https://download.pytorch.org/whl/nightly/cu128 +# Manual installation with prerelease (two-step approach for security): +# Step 1: Install PyTorch prerelease (edit torch-pre.txt to configure channel) +# uv pip install --prerelease=if-necessary-or-explicit torch==2.10.0 --index-url https://download.pytorch.org/whl/nightly/cu128 # # Step 2: Install FTS with Lightning commit pin (torch already installed, will be skipped) # export UV_OVERRIDE=${PWD}/requirements/ci/overrides.txt # uv pip install -e ".[all]" # # Or with locked requirements: -# uv pip install --prerelease=if-necessary-or-explicit torch==2.10.0.dev20251124 --index-url https://download.pytorch.org/whl/nightly/cu128 +# uv pip install --prerelease=if-necessary-or-explicit torch==2.10.0 --index-url https://download.pytorch.org/whl/nightly/cu128 # UV_OVERRIDE=${PWD}/requirements/ci/overrides.txt uv pip install -e . -r requirements/ci/requirements.txt -torch==2.10.0.dev20251124 +torch==2.10.0 diff --git a/requirements/utils/lock_ci_requirements.sh b/requirements/utils/lock_ci_requirements.sh index e6ad2fa..4df0ac1 100755 --- a/requirements/utils/lock_ci_requirements.sh +++ b/requirements/utils/lock_ci_requirements.sh @@ -10,12 +10,12 @@ # which allows uv to properly resolve oldest compatible versions. # # Torch handling: -# - When torch-nightly.txt is configured: -# - Lock file is generated with torch pinned to the nightly version -# - Uses PyTorch nightly index for resolution -# - Docker image and CI both use the same nightly version +# - When torch-pre.txt is configured: +# - Lock file is generated with torch pinned to the prerelease version +# - Uses PyTorch nightly or test index for resolution (based on channel type) +# - Docker image and CI both use the same prerelease version # - Post-processing prunes torch-only dependencies (see prune_torch_only_deps) -# - Without torch-nightly.txt: +# - Without torch-pre.txt: # - Uses stable torch from PyPI # - CI uses --torch-backend=cpu for CPU variant set -eo pipefail @@ -27,7 +27,7 @@ unset UV_OVERRIDE SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" CI_DIR="${REPO_ROOT}/requirements/ci" -TORCH_NIGHTLY_FILE="${CI_DIR}/torch-nightly.txt" +TORCH_PRE_FILE="${CI_DIR}/torch-pre.txt" TORCH_OVERRIDE_FILE="${CI_DIR}/torch_override.txt" # Ensure output directory exists @@ -35,14 +35,17 @@ mkdir -p "${CI_DIR}" echo "Generating locked CI requirements from pyproject.toml..." 
-# Check if torch nightly is configured -# Returns: torch_version if nightly is enabled, empty string otherwise -get_torch_nightly_version() { - if [[ -f "${TORCH_NIGHTLY_FILE}" ]]; then - # Read first non-comment, non-empty line as torch version - local version=$(grep -v '^#' "${TORCH_NIGHTLY_FILE}" | grep -v '^$' | head -1) - if [[ -n "${version}" ]]; then - echo "${version}" +# Check if torch prerelease is configured +# Returns: "version:channel" if prerelease is enabled, empty string otherwise +get_torch_pre_config() { + if [[ -f "${TORCH_PRE_FILE}" ]]; then + # Read 3-line config: version, CUDA target, channel type + readarray -t PRE_CONFIG < <(grep -v '^#' "${TORCH_PRE_FILE}" | grep -v '^$') + local version="${PRE_CONFIG[0]}" + local cuda="${PRE_CONFIG[1]}" + local channel="${PRE_CONFIG[2]}" + if [[ -n "${version}" && -n "${channel}" ]]; then + echo "${version}:${channel}" return fi fi @@ -65,8 +68,15 @@ get_torch_nightly_version() { # which ensures uv can properly resolve oldest compatible versions with # --resolution=lowest-direct without needing external constraint files. -# Check if torch nightly is configured -TORCH_NIGHTLY_VERSION=$(get_torch_nightly_version) +# Check if torch prerelease is configured +TORCH_PRE_CONFIG=$(get_torch_pre_config) +if [[ -n "${TORCH_PRE_CONFIG}" ]]; then + TORCH_PRE_VERSION="${TORCH_PRE_CONFIG%%:*}" # Extract version before ':' + TORCH_PRE_CHANNEL="${TORCH_PRE_CONFIG##*:}" # Extract channel after ':' +else + TORCH_PRE_VERSION="" + TORCH_PRE_CHANNEL="" +fi # Prune packages that are ONLY dependencies of torch from the lockfile. # This reduces the dependency confusion attack surface when using unsafe-best-match @@ -77,33 +87,33 @@ prune_torch_only_deps() { python "${SCRIPT_DIR}/prune_torch_deps.py" "${lockfile}" } -# Generate/update torch_override.txt if nightly is configured, remove if not +# Generate/update torch_override.txt if prerelease is configured, remove if not generate_torch_override() { - if [[ -n "${TORCH_NIGHTLY_VERSION}" ]]; then + if [[ -n "${TORCH_PRE_VERSION}" ]]; then cat > "${TORCH_OVERRIDE_FILE}" << EOF -# PyTorch nightly override for UV_OVERRIDE -# Generated by lock_ci_requirements.sh from torch-nightly.txt +# PyTorch prerelease override for UV_OVERRIDE +# Generated by lock_ci_requirements.sh from torch-pre.txt # -# Manual installation with nightly (two-step approach for security): -# Step 1: Install PyTorch nightly -# uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_NIGHTLY_VERSION} --index-url https://download.pytorch.org/whl/nightly/cu128 +# Manual installation with prerelease (two-step approach for security): +# Step 1: Install PyTorch prerelease (edit torch-pre.txt to configure channel) +# uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_PRE_VERSION} --index-url https://download.pytorch.org/whl/nightly/cu128 # # Step 2: Install FTS with Lightning commit pin (torch already installed, will be skipped) # export UV_OVERRIDE=\${PWD}/requirements/ci/overrides.txt # uv pip install -e ".[all]" # # Or with locked requirements: -# uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_NIGHTLY_VERSION} --index-url https://download.pytorch.org/whl/nightly/cu128 +# uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_PRE_VERSION} --index-url https://download.pytorch.org/whl/nightly/cu128 # UV_OVERRIDE=\${PWD}/requirements/ci/overrides.txt uv pip install -e . 
-r requirements/ci/requirements.txt -torch==${TORCH_NIGHTLY_VERSION} +torch==${TORCH_PRE_VERSION} EOF - echo "✓ Generated ${TORCH_OVERRIDE_FILE} (torch==${TORCH_NIGHTLY_VERSION})" + echo "✓ Generated ${TORCH_OVERRIDE_FILE} (torch==${TORCH_PRE_VERSION})" else - # Remove stale override file if nightly is disabled + # Remove stale override file if prerelease is disabled if [[ -f "${TORCH_OVERRIDE_FILE}" ]]; then rm -f "${TORCH_OVERRIDE_FILE}" - echo "✓ Removed ${TORCH_OVERRIDE_FILE} (nightly disabled)" + echo "✓ Removed ${TORCH_OVERRIDE_FILE} (prerelease disabled)" fi fi } @@ -112,14 +122,17 @@ generate_lockfile() { local resolution=$1 local output_file=$2 local python_version=$3 - local use_nightly=$4 # "true" to use torch nightly + local use_prerelease=$4 # "true" to use torch prerelease echo "Generating ${output_file} with resolution=${resolution}, python=${python_version}..." + # Change to repo root for dependency group resolution + pushd "${REPO_ROOT}" > /dev/null + # Build the base compile command local compile_cmd=( uv pip compile - "${REPO_ROOT}/pyproject.toml" + pyproject.toml --extra all --group dev --group test @@ -131,42 +144,45 @@ generate_lockfile() { --python-version "${python_version}" ) - # When using torch nightly: - # 1. Create a temporary override file to pin torch to the nightly version for dependency resolution + # When using torch prerelease (nightly or test): + # 1. Create a temporary override file to pin torch to the prerelease version for dependency resolution # 2. Use --prerelease=if-necessary-or-explicit to only allow prereleases for explicitly specified packages (torch) # or where all versions of the package are pre-release - # 3. Use --index with nightly CPU index for torch resolution + # 3. Use --index with prerelease CPU index for torch resolution (nightly or test channel) # 4. Use --index-strategy=unsafe-best-match for lockfile GENERATION only # This is required because with first-index (default), uv would either: # - Find torch on PyPI first (no nightly version), or # - Find scipy/etc on nightly index first (missing versions) # # Security rationale for using unsafe-best-match during lockfile generation: - # a) User INSTALLATION uses a secure two-step approach, only ever installing torch nightly - # from the explicitly specified nightly index (no unsafe-best-match at install time) + # a) User INSTALLATION uses a secure two-step approach, only ever installing torch prerelease + # from the explicitly specified prerelease index (no unsafe-best-match at install time) # b) The marginal dependency confusion attack surface is limited to the closely monitored - # PyTorch nightly index, which is maintained by PyTorch team. Post-processing prunes any packages that are + # PyTorch prerelease index, which is maintained by PyTorch team. Post-processing prunes any packages that are # ONLY dependencies of torch, eliminating potential attack vectors from torch-exclusive dependencies that - # might only exist on the nightly index. If a package is shared with other dependencies, it's already + # might only exist on the prerelease index. If a package is shared with other dependencies, it's already # being resolved from PyPI and subject to normal security scanning. # c) Lockfile generation runs on maintainer machines, not user machines # d) Generated lockfile pins exact package versions from PyPI # # 5. Use --no-emit-package=torch to exclude torch from output (installed separately with backend) # 6. 
Post-process to prune torch-only dependencies (see prune_torch_only_deps) - if [[ "${use_nightly}" == "true" && -n "${TORCH_NIGHTLY_VERSION}" ]]; then + if [[ "${use_prerelease}" == "true" && -n "${TORCH_PRE_VERSION}" ]]; then + # Determine channel from global variable + local channel="${TORCH_PRE_CHANNEL:-nightly}" # Default to nightly if not set + local torch_override_file=$(mktemp) - echo "torch==${TORCH_NIGHTLY_VERSION}" > "${torch_override_file}" + echo "torch==${TORCH_PRE_VERSION}" > "${torch_override_file}" compile_cmd+=( --prerelease=if-necessary-or-explicit --override "${torch_override_file}" - --index "https://download.pytorch.org/whl/nightly/cpu" + --index "https://download.pytorch.org/whl/${channel}/cpu" --index-strategy unsafe-best-match # for lockfile generation only, see comment above --no-emit-package torch ) - echo " Using torch nightly: ${TORCH_NIGHTLY_VERSION} (excluded from output, dependencies resolved)" + echo " Using torch ${channel}: ${TORCH_PRE_VERSION} (excluded from output, dependencies resolved)" "${compile_cmd[@]}" rm -f "${torch_override_file}" @@ -174,11 +190,14 @@ generate_lockfile() { # Prune torch-only dependencies to minimize dependency confusion attack surface prune_torch_only_deps "${output_file}" - echo "✓ Generated ${output_file} (torch ${TORCH_NIGHTLY_VERSION} excluded, torch-only deps pruned)" + echo "✓ Generated ${output_file} (torch ${TORCH_PRE_VERSION} excluded, torch-only deps pruned)" else "${compile_cmd[@]}" echo "✓ Generated ${output_file}" fi + + # Return to original directory + popd > /dev/null } # Generate both lock files @@ -186,21 +205,25 @@ generate_lockfile() { # packages like contourpy that have Python version requirements # - Oldest: Python 3.10 (minimum supported), lowest resolution # -# When torch nightly is configured: +# When torch prerelease (nightly or test) is configured: # - requirements.txt excludes torch (installed separately with appropriate backend) -# - torch dependencies are still resolved against the nightly version +# - torch dependencies are still resolved against the prerelease version # - requirements-oldest.txt uses stable torch (for minimum version testing) -USE_NIGHTLY="false" -if [[ -n "${TORCH_NIGHTLY_VERSION}" ]]; then - USE_NIGHTLY="true" - echo "Torch nightly mode: ${TORCH_NIGHTLY_VERSION}" +USE_PRERELEASE="false" +if [[ -n "${TORCH_PRE_VERSION}" ]]; then + USE_PRERELEASE="true" + echo "Torch ${TORCH_PRE_CHANNEL} mode: ${TORCH_PRE_VERSION}" fi # Generate torch override file for manual installation generate_torch_override -generate_lockfile "highest" "${CI_DIR}/requirements.txt" "3.10" "${USE_NIGHTLY}" +# Sync metadata from utils.py to pyproject.toml before generating lockfiles +echo "Syncing metadata from utils.py to pyproject.toml..." 
+python "${SCRIPT_DIR}/sync_metadata.py" + +generate_lockfile "highest" "${CI_DIR}/requirements.txt" "3.10" "${USE_PRERELEASE}" generate_lockfile "lowest-direct" "${CI_DIR}/requirements-oldest.txt" "3.10" "false" echo "" @@ -208,29 +231,29 @@ echo "Generated lock files:" echo " - ${CI_DIR}/requirements.txt (highest resolution, for latest tests)" echo " - ${CI_DIR}/requirements-oldest.txt (lowest resolution, for oldest tests)" echo "" -if [[ -n "${TORCH_NIGHTLY_VERSION}" ]]; then - echo "⚠️ Torch nightly mode: ${TORCH_NIGHTLY_VERSION}" - echo " requirements.txt excludes torch (dependencies resolved against nightly)" +if [[ -n "${TORCH_PRE_VERSION}" ]]; then + echo "⚠️ Torch ${TORCH_PRE_CHANNEL} mode: ${TORCH_PRE_VERSION}" + echo " requirements.txt excludes torch (dependencies resolved against ${TORCH_PRE_CHANNEL})" echo "" echo "Generated override file:" - echo " - ${CI_DIR}/torch_override.txt (for manual nightly installation reference)" + echo " - ${CI_DIR}/torch_override.txt (for manual prerelease installation reference)" echo "" - echo "Manual installation with nightly (two-step approach):" - echo " 1. uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_NIGHTLY_VERSION} --index-url https://download.pytorch.org/whl/nightly/cu128" + echo "Manual installation with prerelease (two-step approach):" + echo " 1. uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_PRE_VERSION} --index-url https://download.pytorch.org/whl/${TORCH_PRE_CHANNEL}/cu128" echo " 2. UV_OVERRIDE=requirements/ci/overrides.txt uv pip install -e \".[all]\"" echo "" echo "Or with locked requirements:" - echo " 1. uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_NIGHTLY_VERSION} --index-url https://download.pytorch.org/whl/nightly/cu128" + echo " 1. uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_PRE_VERSION} --index-url https://download.pytorch.org/whl/${TORCH_PRE_CHANNEL}/cu128" echo " 2. UV_OVERRIDE=requirements/ci/overrides.txt uv pip install -e . -r requirements/ci/requirements.txt" echo "" echo "Docker image installation (CUDA):" - echo " Ensure Dockerfile installs: torch==${TORCH_NIGHTLY_VERSION} from nightly/cu128 index" + echo " Ensure Dockerfile installs: torch==${TORCH_PRE_VERSION} from ${TORCH_PRE_CHANNEL}/cu128 index" echo "" echo "Azure Pipelines (Docker with pre-installed torch):" echo " UV_OVERRIDE=requirements/ci/overrides.txt uv pip install -e . -r requirements/ci/requirements.txt" echo "" echo "GitHub Actions (CPU):" - echo " 1. uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_NIGHTLY_VERSION}+cpu --index-url https://download.pytorch.org/whl/nightly/cpu" + echo " 1. uv pip install --prerelease=if-necessary-or-explicit torch==${TORCH_PRE_VERSION}+cpu --index-url https://download.pytorch.org/whl/${TORCH_PRE_CHANNEL}/cpu" echo " 2. UV_OVERRIDE=requirements/ci/overrides.txt uv pip install -e . -r requirements/ci/requirements.txt" else echo "Standard installation (stable torch):" diff --git a/requirements/utils/sync_metadata.py b/requirements/utils/sync_metadata.py new file mode 100755 index 0000000..3c14b10 --- /dev/null +++ b/requirements/utils/sync_metadata.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +"""Sync minimum version metadata from utils.py to pyproject.toml. + +This script reads the actual version constraints from +src/finetuning_scheduler/dynamic_versioning/utils.py and updates +the informational [tool.fts.min-versions] section in pyproject.toml +to match. 
+ +This ensures the user-facing metadata stays synchronized with the +actual version constraints used during installation. +""" + +import re +import sys +from pathlib import Path + + +def extract_torch_version(utils_content: str) -> str: + """Extract torch version constraint from BASE_DEPENDENCIES. + + Args: + utils_content: Content of utils.py + + Returns: + Full torch version constraint (e.g., ">=2.6.0") + """ + # Look for: BASE_DEPENDENCIES = [..., "torch>=X.Y.Z", ...] + match = re.search(r'BASE_DEPENDENCIES\s*=\s*\[.*?"torch(>=?[0-9.]+(?:,[<>=0-9.]+)?)"', utils_content, re.DOTALL) + if not match: + raise ValueError("Could not find torch version in BASE_DEPENDENCIES") + return match.group(1) + + +def extract_lightning_version(utils_content: str) -> str: + """Extract lightning version constraint from LIGHTNING_VERSION. + + Args: + utils_content: Content of utils.py + + Returns: + Full lightning version constraint (e.g., ">=2.6.0,<2.6.1") + """ + # Look for: LIGHTNING_VERSION = ">=X.Y.Z,<X.Y.W" + match = re.search(r'LIGHTNING_VERSION\s*=\s*"(>=?[0-9.]+(?:,[<>=0-9.]+)?)"', utils_content) + if not match: + raise ValueError("Could not find LIGHTNING_VERSION") + return match.group(1) + + +def extract_python_version(pyproject_content: str) -> str: + """Extract minimum Python version from pyproject.toml requires-python. + + Args: + pyproject_content: Content of pyproject.toml + + Returns: + Minimum python version (e.g., "3.10") + """ + # Look for: requires-python = ">=X.Y" + match = re.search(r'requires-python\s*=\s*">=([0-9.]+)"', pyproject_content) + if not match: + raise ValueError("Could not find requires-python in pyproject.toml") + return match.group(1) + + +def update_metadata_section(pyproject_content: str, torch_ver: str, lightning_ver: str, python_ver: str) -> str: + """Update the [tool.fts.min-versions] section with new versions. 
+ + Args: + pyproject_content: Content of pyproject.toml + torch_ver: Minimum torch version + lightning_ver: Minimum lightning version + python_ver: Minimum python version + + Returns: + Updated pyproject.toml content + """ + # Use a line-by-line approach for safer replacement + lines = pyproject_content.split('\n') + updated_lines = [] + in_metadata_section = False + metadata_updated = False + + for i, line in enumerate(lines): + if '[tool.fts.min-versions]' in line: + in_metadata_section = True + updated_lines.append(line) + elif in_metadata_section: + # Update torch, lightning, python lines + if line.strip().startswith('torch'): + updated_lines.append(f'torch = "{torch_ver}"') + metadata_updated = True + elif line.strip().startswith('lightning'): + updated_lines.append(f'lightning = "{lightning_ver}"') + elif line.strip().startswith('python'): + updated_lines.append(f'python = "{python_ver}" # Defined in pyproject.toml requires-python') + in_metadata_section = False # End of section + else: + updated_lines.append(line) + else: + updated_lines.append(line) + + if not metadata_updated: + raise ValueError("Could not find [tool.fts.min-versions] section to update") + + return '\n'.join(updated_lines) + + +def main(): + """Main entry point.""" + # Locate files + script_dir = Path(__file__).parent + repo_root = script_dir.parent.parent + + utils_path = repo_root / "src" / "finetuning_scheduler" / "dynamic_versioning" / "utils.py" + pyproject_path = repo_root / "pyproject.toml" + + if not utils_path.exists(): + print(f"Error: Could not find {utils_path}", file=sys.stderr) + return 1 + + if not pyproject_path.exists(): + print(f"Error: Could not find {pyproject_path}", file=sys.stderr) + return 1 + + # Read files + utils_content = utils_path.read_text() + pyproject_content = pyproject_path.read_text() + + # Extract versions + try: + torch_ver = extract_torch_version(utils_content) + lightning_ver = extract_lightning_version(utils_content) + python_ver = extract_python_version(pyproject_content) + except ValueError as e: + print(f"Error extracting versions: {e}", file=sys.stderr) + return 1 + + print("Extracted versions from utils.py:") + print(f" torch: {torch_ver}") + print(f" lightning: {lightning_ver}") + print(f" python: {python_ver} (from pyproject.toml requires-python)") + + # Update pyproject.toml + try: + updated_content = update_metadata_section(pyproject_content, torch_ver, lightning_ver, python_ver) + except ValueError as e: + print(f"Error updating pyproject.toml: {e}", file=sys.stderr) + return 1 + + # Check if anything changed + if updated_content == pyproject_content: + print("✓ Metadata already up to date") + return 0 + + # Write updated content + pyproject_path.write_text(updated_content) + print(f"✓ Updated {pyproject_path}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/build_fts_env.sh b/scripts/build_fts_env.sh index 61efa4f..28481de 100755 --- a/scripts/build_fts_env.sh +++ b/scripts/build_fts_env.sh @@ -10,17 +10,14 @@ # ./build_fts_env.sh --repo_home=${HOME}/repos/finetuning-scheduler --target_env_name=fts_oldest --oldest # build release: # ./build_fts_env.sh --repo_home=${HOME}/repos/fts-release --target_env_name=fts_release -# build latest with torch test channel: -# ./build_fts_env.sh --repo_home=~/repos/finetuning-scheduler --target_env_name=fts_latest --torch_test_channel # build latest from a package from source: # ./build_fts_env.sh --repo_home=${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest 
--from-source="lightning:${HOME}/repos/lightning:pytorch" # -# To use a specific PyTorch nightly, edit requirements/ci/torch-nightly.txt +# To configure PyTorch version (nightly/test/stable), edit requirements/ci/torch-pre.txt set -eo pipefail unset repo_home unset target_env_name -unset torch_test_channel unset uv_install_flags unset no_commit_pin unset venv_dir @@ -28,6 +25,7 @@ unset oldest declare -a from_source_specs=() SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" source "${SCRIPT_DIR}/infra_utils.sh" usage(){ @@ -36,7 +34,6 @@ Usage: $0 [ --repo_home input] [ --target_env_name input ] [ --oldest ] # Use oldest CI requirements (Python 3.10, requirements-oldest.txt) - [ --torch_test_channel ] # Use PyTorch test/RC channel [ --uv_install_flags "flags" ] [ --no_commit_pin ] [ --venv-dir input ] @@ -51,8 +48,6 @@ Usage: $0 # ./build_fts_env.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_oldest --oldest --venv-dir=/mnt/cache/\${USER}/.venvs # build release: # ./build_fts_env.sh --repo_home=\${HOME}/repos/fts-release --target_env_name=fts_release - # build latest with torch test channel: - # ./build_fts_env.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --torch_test_channel # build latest with no cache: # ./build_fts_env.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --uv_install_flags="--no-cache" # build latest without using CI commit pinning: @@ -60,14 +55,15 @@ Usage: $0 # build latest from Lightning source: # ./build_fts_env.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --from-source="lightning:\${HOME}/repos/lightning:pytorch" - # To use a specific PyTorch nightly, edit requirements/ci/torch-nightly.txt: - # Line 1: torch version (e.g., 2.10.0.dev20251124) + # To configure PyTorch version, edit requirements/ci/torch-pre.txt: + # Line 1: torch version (e.g., 2.10.0 for test, 2.10.0.dev20251124 for nightly) # Line 2: CUDA target (e.g., cu128) + # Line 3: channel type (test or nightly) EOF exit 1 } -args=$(getopt -o '' --long repo_home:,target_env_name:,oldest,torch_test_channel,uv_install_flags:,no_commit_pin,venv-dir:,from-source:,help -- "$@") +args=$(getopt -o '' --long repo_home:,target_env_name:,oldest,uv_install_flags:,no_commit_pin,venv-dir:,from-source:,help -- "$@") if [[ $? 
-gt 0 ]]; then usage fi @@ -79,7 +75,6 @@ do --repo_home) repo_home=$2 ; shift 2 ;; --target_env_name) target_env_name=$2 ; shift 2 ;; --oldest) oldest=1 ; shift ;; - --torch_test_channel) torch_test_channel=1 ; shift ;; --uv_install_flags) uv_install_flags=$2 ; shift 2 ;; --no_commit_pin) no_commit_pin=1 ; shift ;; --venv-dir) venv_dir=$2 ; shift 2 ;; @@ -109,32 +104,7 @@ if [[ ${#from_source_specs[@]} -gt 0 ]]; then from_source_spec=$(IFS=';'; echo "${from_source_specs[*]}") fi -# Read torch nightly configuration from requirements/ci/torch-nightly.txt -# Returns two values via global variables: TORCH_NIGHTLY_VERSION and TORCH_NIGHTLY_CUDA -read_torch_nightly_config() { - local nightly_file="${repo_home}/requirements/ci/torch-nightly.txt" - TORCH_NIGHTLY_VERSION="" - TORCH_NIGHTLY_CUDA="" - - if [[ -f "${nightly_file}" ]]; then - # Read non-comment, non-empty lines - local lines=() - while IFS= read -r line || [[ -n "$line" ]]; do - # Skip comments and empty lines - [[ "$line" =~ ^# ]] && continue - [[ -z "$line" ]] && continue - lines+=("$line") - done < "${nightly_file}" - - # First line is torch version, second is CUDA target - if [[ ${#lines[@]} -ge 1 ]]; then - TORCH_NIGHTLY_VERSION="${lines[0]}" - fi - if [[ ${#lines[@]} -ge 2 ]]; then - TORCH_NIGHTLY_CUDA="${lines[1]}" - fi - fi -} +# Read torch prerelease configuration (now handled by infra_utils.sh::read_torch_pre_config) clear_activate_env(){ local python_version=$1 @@ -159,33 +129,29 @@ base_env_build(){ clear_activate_env ${python_version} - # Check for torch nightly configuration (skip for oldest builds) + # Check for torch prerelease configuration (skip for oldest builds) if [[ -z ${oldest} ]]; then - read_torch_nightly_config + read_torch_pre_config fi # Handle PyTorch version selection (pre-install before FTS dependencies) - # Priority: oldest (stable from lock) > torch nightly from config > torch test channel > stable (via --torch-backend in fts_install) + # Priority: oldest (stable from lock) > torch prerelease from config > stable (via --torch-backend in fts_install) if [[ -n ${oldest} ]]; then # For oldest builds, torch is installed from requirements-oldest.txt (stable version) echo "Using torch stable from requirements-oldest.txt for oldest build" - elif [[ -n "${TORCH_NIGHTLY_VERSION}" ]]; then - # Nightly version from torch-nightly.txt with specified CUDA backend - local cuda_target="${TORCH_NIGHTLY_CUDA:-cu128}" # Default to cu128 if not specified - local torch_pkg="torch==${TORCH_NIGHTLY_VERSION}" - local torch_index_url="https://download.pytorch.org/whl/nightly/${cuda_target}" - echo "Pre-installing PyTorch nightly from torch-nightly.txt: ${torch_pkg}" + elif [[ -n "${TORCH_PRE_VERSION}" ]]; then + # Prerelease (nightly or test) configured in torch-pre.txt + local cuda_target="${TORCH_PRE_CUDA:-cu128}" # Default to cu128 if not specified + local torch_pkg="torch==${TORCH_PRE_VERSION}" + local torch_index_url=$(get_torch_index_url "${TORCH_PRE_CHANNEL}" "${cuda_target}") + + echo "Pre-installing PyTorch ${TORCH_PRE_CHANNEL} from torch-pre.txt: ${torch_pkg}" + echo " Channel: ${TORCH_PRE_CHANNEL}" echo " CUDA target: ${cuda_target}" echo " Index URL: ${torch_index_url}" + uv pip install ${uv_install_flags} --prerelease=allow "${torch_pkg}" --index-url "${torch_index_url}" - log_torch_version "after PyTorch nightly pre-install" - elif [[ -n ${torch_test_channel} ]]; then - # Test/RC channel - pre-install torch with test index and auto backend for GPU detection - local 
torch_index_url="https://download.pytorch.org/whl/test" - echo "Pre-installing PyTorch from test channel: ${torch_index_url}" - echo " Using --torch-backend=auto for GPU auto-detection" - uv pip install ${uv_install_flags} --prerelease=allow torch --index-url ${torch_index_url} --torch-backend=auto - log_torch_version "after PyTorch test channel pre-install" + log_torch_version "after PyTorch ${TORCH_PRE_CHANNEL} pre-install" fi # For stable builds, torch will be installed via FTS dependencies with --torch-backend=auto } @@ -218,16 +184,10 @@ fts_install(){ echo "Using oldest requirements file: ${req_file}" # Oldest builds use torch stable from lock file, need --torch-backend=auto torch_backend_flag="--torch-backend=auto" - elif [[ -n "${TORCH_NIGHTLY_VERSION}" || -n ${torch_test_channel} ]]; then - # Torch already pre-installed (nightly or test channel) - # When nightly: requirements.txt already has torch filtered during lock generation - # When test channel: filter at runtime - if [[ -n ${torch_test_channel} ]]; then - echo "Torch test channel pre-installed, filtering torch from requirements..." - grep -v '^torch==' "${req_file}" > /tmp/requirements_no_torch.txt - req_file="/tmp/requirements_no_torch.txt" - fi - echo "Using requirements without torch (pre-installed)" + elif [[ -n "${TORCH_PRE_VERSION}" ]]; then + # Torch prerelease already pre-installed (nightly or test channel) + # requirements.txt already has torch filtered during lock generation + echo "Using requirements without torch (pre-installed ${TORCH_PRE_CHANNEL})" else # Use auto torch backend for GPU detection torch_backend_flag="--torch-backend=auto" diff --git a/scripts/gen_fts_coverage.sh b/scripts/gen_fts_coverage.sh index 68b856f..1c14249 100755 --- a/scripts/gen_fts_coverage.sh +++ b/scripts/gen_fts_coverage.sh @@ -5,8 +5,6 @@ set -eo pipefail unset repo_home unset target_env_name -unset torch_dev_ver -unset torch_test_channel unset no_rebuild_base unset include_experimental unset uv_install_flags @@ -26,8 +24,6 @@ Usage: $0 [ --repo_home input] [ --target_env_name input ] [ --oldest ] # Use oldest CI requirements (Python 3.10, requirements-oldest.txt) - [ --torch_dev_ver input ] - [ --torch_test_channel ] [ --no_rebuild_base ] [ --no-special ] # Skip special tests (standalone/experimental), run only main test suite [ --include_experimental ] @@ -42,25 +38,20 @@ Usage: $0 # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --no_rebuild_base # generate oldest CI build coverage (matches CI oldest matrix): # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_oldest --oldest --no-special --venv-dir=/mnt/cache/\${USER}/.venvs - # generate fts_latest coverage with a given torch_dev_version: - # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --torch_dev_ver=dev20240201 - # generate fts_latest coverage, rebuilding base fts_latest with PyTorch test channel and run tests that require experimental patches: - # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --torch_test_channel --include_experimental - # generate fts_release coverage, rebuilding the base fts_release environment with PyTorch stable channel: + # generate fts_release coverage, rebuilding the base fts_release environment: # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/fts-release --target_env_name=fts_release # generate fts_release coverage, rebuilding the base fts_release 
environment with PyTorch test channel: - # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/fts-release --target_env_name=fts_release --torch_test_channel # generate fts_latest coverage with explicit venv directory (recommended for hardlink performance): # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --venv-dir=/mnt/cache/\${USER}/.venvs # generate fts_release coverage without using CI commit pinning: # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/fts-release --target_env_name=fts_release --no_commit_pin # dry-run mode: setup environment and show what tests would run without executing them: - # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --torch_dev_ver=dev20240201 --dry-run + # ./gen_fts_coverage.sh --repo_home=\${HOME}/repos/finetuning-scheduler --target_env_name=fts_latest --dry-run EOF exit 1 } -args=$(getopt -o '' --long repo_home:,target_env_name:,oldest,torch_dev_ver:,torch_test_channel,no_rebuild_base,no-special,include_experimental,uv_install_flags:,no_commit_pin,venv-dir:,from-source:,dry-run,help -- "$@") +args=$(getopt -o '' --long repo_home:,target_env_name:,oldest,no_rebuild_base,no-special,include_experimental,uv_install_flags:,no_commit_pin,venv-dir:,from-source:,dry-run,help -- "$@") if [[ $? -gt 0 ]]; then usage fi @@ -72,8 +63,6 @@ do --repo_home) repo_home=$2 ; shift 2 ;; --target_env_name) target_env_name=$2 ; shift 2 ;; --oldest) oldest=1 ; shift ;; - --torch_dev_ver) torch_dev_ver=$2 ; shift 2 ;; - --torch_test_channel) torch_test_channel=1 ; shift ;; --no_rebuild_base) no_rebuild_base=1 ; shift ;; --no-special) no_special=1 ; shift ;; --include_experimental) include_experimental=1 ; shift ;; @@ -159,18 +148,10 @@ env_rebuild(){ case $1 in fts_latest|fts_oldest) - if [[ -n ${torch_dev_ver} ]]; then - cmd_args+=("--torch_dev_ver=${torch_dev_ver}") - elif [[ $torch_test_channel -eq 1 ]]; then - cmd_args+=("--torch_test_channel") - fi log_msg "Final build command: ${cmd_args[*]}" "${cmd_args[@]}" ;; fts_release) - if [[ $torch_test_channel -eq 1 ]]; then - cmd_args+=("--torch_test_channel") - fi log_msg "Final build command: ${cmd_args[*]}" "${cmd_args[@]}" ;; diff --git a/scripts/infra_utils.sh b/scripts/infra_utils.sh index ede420f..a7c174a 100755 --- a/scripts/infra_utils.sh +++ b/scripts/infra_utils.sh @@ -248,3 +248,56 @@ install_from_source_packages(){ fi done } + +# Read torch prerelease configuration from requirements/ci/torch-pre.txt +# Returns values via global variables: +# TORCH_PRE_VERSION - torch version to install +# TORCH_PRE_CUDA - CUDA target for local builds +# TORCH_PRE_CHANNEL - channel type: "test" or "nightly" +# Note: Requires REPO_ROOT to be set (usually via SCRIPT_DIR) +read_torch_pre_config() { + TORCH_PRE_VERSION="" + TORCH_PRE_CUDA="" + TORCH_PRE_CHANNEL="" + + # Determine repo root from SCRIPT_DIR if not already set + local repo_root="${REPO_ROOT:-$(cd "${SCRIPT_DIR}/.." && pwd)}" + local pre_file="${repo_root}/requirements/ci/torch-pre.txt" + + if [[ ! 
-f "${pre_file}" ]]; then + return + fi + + # Read non-comment, non-empty lines + local lines=($(grep -v '^#' "${pre_file}" | grep -v '^$' || true)) + + if [[ ${#lines[@]} -ge 3 ]]; then + TORCH_PRE_VERSION="${lines[0]}" + TORCH_PRE_CUDA="${lines[1]}" + TORCH_PRE_CHANNEL="${lines[2]}" + + # Validate channel + if [[ "${TORCH_PRE_CHANNEL}" != "test" && "${TORCH_PRE_CHANNEL}" != "nightly" ]]; then + echo "ERROR: Invalid channel '${TORCH_PRE_CHANNEL}' in ${pre_file}" >&2 + echo "Must be 'test' or 'nightly'" >&2 + return 1 + fi + fi +} + +# Get torch index URL based on channel and CUDA target +# Args: $1 = channel ("test" or "nightly"), $2 = cuda_target (e.g., "cu128" or "cpu") +# Returns: PyTorch wheel index URL +get_torch_index_url() { + local channel="$1" + local cuda_target="${2:-cpu}" + + if [[ "${channel}" == "test" ]]; then + echo "https://download.pytorch.org/whl/test/${cuda_target}" + elif [[ "${channel}" == "nightly" ]]; then + echo "https://download.pytorch.org/whl/nightly/${cuda_target}" + else + echo "ERROR: Invalid channel: ${channel}" >&2 + return 1 + fi +} diff --git a/src/finetuning_scheduler/dynamic_versioning/utils.py b/src/finetuning_scheduler/dynamic_versioning/utils.py index abf03b7..4f08578 100644 --- a/src/finetuning_scheduler/dynamic_versioning/utils.py +++ b/src/finetuning_scheduler/dynamic_versioning/utils.py @@ -18,6 +18,12 @@ # ----------------------------------------------------------------------------- # Lightning Configuration # ----------------------------------------------------------------------------- +# +# These version constraints are the single source of truth for minimum versions. +# They are used by setup.py to generate dynamic dependencies at build time. +# +# For visibility, these values are also documented in pyproject.toml under +# [tool.fts.min-versions] (informational only - not used during installation). # Shared version constraint for all Lightning packages LIGHTNING_VERSION = ">=2.6.0,<2.6.1" @@ -43,6 +49,9 @@ # Base dependencies (torch + Lightning are handled dynamically) # These are the core dependencies that are always installed +# +# Note: For visibility, minimum versions are also documented in pyproject.toml +# under [tool.fts.min-versions] (informational only). BASE_DEPENDENCIES = [ "torch>=2.6.0", ] diff --git a/src/finetuning_scheduler/fts.py b/src/finetuning_scheduler/fts.py index 26b6202..f387307 100644 --- a/src/finetuning_scheduler/fts.py +++ b/src/finetuning_scheduler/fts.py @@ -9,12 +9,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -r""" -Fine-Tuning Scheduler -^^^^^^^^^^^^^^^^^^^^^ +r"""Fine-Tuning Scheduler. 
Used to implement flexible fine-tuning training schedules - """ import logging import os @@ -41,6 +38,7 @@ ScheduleImplMixin, ScheduleParsingMixin, STRATEGY_ADAPTERS, + _discover_strategy_adapters ) from finetuning_scheduler.strategy_adapters.base import StrategyAdapter from finetuning_scheduler.types import ParamGroupAddable @@ -111,7 +109,7 @@ def __init__( reinit_optim_cfg: Optional[Dict] = None, reinit_lr_cfg: Optional[Dict] = None, strategy_adapter_cfg: Optional[Dict] = None, - custom_strategy_adapter: Optional[Dict[str, str]] = None, + custom_strategy_adapters: Optional[Dict[str, str]] = None, allow_untested: bool = False, apply_lambdas_new_pgs: bool = False, logging_level: int = logging.INFO, @@ -223,11 +221,15 @@ def __init__( :external+pl:class:`~lightning.pytorch.strategies.Strategy`. See the relevant :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` documentation for strategy-specific configuration options. Defaults to None. - custom_strategy_adapter: A dictionary associating the canonical ``strategy_flag`` associated with a - :external+pl:class:`~lightning.pytorch.strategies.Strategy` (potentially a custom user-registered one) - to the fully qualified path of a - :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` subclass. This is an experimental - feature that is subject to change. Requires ``allow_untested`` to be set to ``True``. Defaults to None. + custom_strategy_adapters: A dictionary mapping PyTorch Lightning strategy flags (canonical strategy + names like ``"single_device"``, ``"auto"``, ``"ddp"``, etc.) to strategy adapter references. Multiple + ``strategy_flag`` keys can be associated with the same adapter. The adapter reference can be: (1) an + entry point name registered under ``finetuning_scheduler.strategy_adapters`` (see + :ref:`strategy_adapter_entry_points`) (2) a fully + qualified :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` subclass path in the + format ``"module.path:ClassName"`` or (3) a fully qualified dot path in the format + ``"module.path.ClassName"``. This is an experimental feature that is subject to change. + Defaults to None. apply_lambdas_new_pgs: If ``True``, applies most recent lambda in ``lr_lambdas`` list to newly added optimizer groups for lr schedulers that have a ``lr_lambdas`` attribute. Note this option only applies to phases without reinitialized lr schedulers. 
Phases with defined lr scheduler reinitialization configs @@ -270,7 +272,7 @@ def __init__( self.reinit_optim_cfg = reinit_optim_cfg self.reinit_lr_cfg = reinit_lr_cfg self.strategy_adapter_cfg = strategy_adapter_cfg or {} - self.custom_strategy_adapter = custom_strategy_adapter + self.custom_strategy_adapters = custom_strategy_adapters self.allow_untested = allow_untested self.apply_lambdas_new_pgs = apply_lambdas_new_pgs self.enforce_phase0_params = enforce_phase0_params @@ -527,6 +529,19 @@ def restore_best_ckpt(self) -> None: except KeyError as ke: # we may want to allow training to progress conditioned on context of restoration self._maybe_allow_incompatible_reinit_ckpt(ke) self.trainer._checkpoint_connector.restore_datamodule() + + # Inspect state before strategy adapter transformation + loaded_ckpt = self.trainer._checkpoint_connector._loaded_checkpoint + + # Allow strategy-specific adapters to transform the state dict before model restoration + if self.strategy_adapter is not None: + try: + loaded_ckpt = self.strategy_adapter.before_restore_model(loaded_ckpt) + # assign back in case adapter replaced or modified the checkpoint + self.trainer._checkpoint_connector._loaded_checkpoint = loaded_ckpt + except Exception as err: + rank_zero_warn(f"Strategy adapter before_restore_model hook raised: {err}") + self.trainer._checkpoint_connector.restore_model() # we need to override checkpoint_connector.restore_training_state() to bypass loop restoration # if additional customizations are required, may make sense to subclass _CheckpointConnector at some point @@ -656,6 +671,7 @@ def _strategy_setup(self, trainer: "pl.Trainer") -> None: connect_flg = getattr(trainer._accelerator_connector, "_strategy_flag", "") strategy_flag = getattr(connect_flg, "strategy_name", connect_flg.__class__.__name__.lower()) if \ isinstance(connect_flg, Strategy) else connect_flg + _discover_strategy_adapters() # Discover strategy adapter plugins lazily supported = [t.lower() for t in self._supported_strategy_flags()] if strategy_flag and strategy_flag not in supported: # type: ignore[attr-defined] if not self.allow_untested: @@ -671,8 +687,8 @@ def _strategy_setup(self, trainer: "pl.Trainer") -> None: f" '{strategy}' because ``allow_untested`` is ``True``." # type: ignore[attr-defined] ) rank_zero_warn(warn_msg) - if self.custom_strategy_adapter: - strategy_cls = self._import_strategy_adapter(strategy_flag, self.custom_strategy_adapter) + if self.custom_strategy_adapters: + strategy_cls = self._resolve_strategy_adapter(strategy_flag, self.custom_strategy_adapters) rank_zero_info( f"Imported custom strategy adapter class type `{strategy_cls}` associated with the current strategy" f" `{strategy_flag}`." @@ -703,13 +719,15 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s Raises: SystemExit: Gracefully exit before training if only generating and not executing a fine-tuning schedule. 
""" + # TODO: might not be necessary if we move introspection back to IT where it belongs + # Save pl_module/trainer refs early to allow strategy adapter selection to introspect the module + self.pl_module, self.trainer = pl_module, trainer # save pl_module/trainer refs for downstream convenience self._callback_dep_setup(trainer, pl_module, stage) self._strategy_setup(trainer) - self.pl_module, self.trainer = pl_module, trainer # save pl_module/trainer refs for downstream convenience if self.gen_ft_sched_only: if trainer.is_global_zero: assert self.log_dir is not None, "log_dir must be set to generate a fine-tuning schedule" - _ = ScheduleImplMixin.gen_ft_schedule(self.pl_module, self.log_dir) + _ = self.strategy_adapter.gen_ft_schedule(self.log_dir) log.info("Bypassing training, generating fine-tuning schedule for review and subsequent fine-tuning") raise SystemExit(0) if not self.epoch_transitions_only: diff --git a/src/finetuning_scheduler/fts_supporters.py b/src/finetuning_scheduler/fts_supporters.py index 2991acd..9583032 100644 --- a/src/finetuning_scheduler/fts_supporters.py +++ b/src/finetuning_scheduler/fts_supporters.py @@ -67,6 +67,55 @@ STRATEGY_ADAPTERS = {"fsdp": FSDPStrategyAdapter, "modelparallelstrategy": ModelParallelStrategyAdapter} +def _discover_strategy_adapters() -> None: + """Discover strategy adapter plugins via entry points. + + This tries to discover user-contributed adapters registered under the + ``finetuning_scheduler.strategy_adapters`` entry point group and extends the + runtime ``STRATEGY_ADAPTERS`` mapping with any discovered adapters keyed by + the entry point name (lowercased). + + Note: + This function is called lazily during strategy setup. + + .. warning:: + This is an :ref:`experimental ` feature which is + still in development. The entry point API and plugin discovery mechanism may change in future + releases. See :ref:`strategy_adapter_entry_points` for documentation. + """ + # We now require Python 3.10+ environments so can use the standard importlib.metadata API. + from importlib.metadata import entry_points + eps = entry_points(group="finetuning_scheduler.strategy_adapters") + for ep in eps: + try: + # prefer using the standard entrypoint loader which handles 'module:attr' + try: + cls = ep.load() + except Exception as ep_err: + # fall back to dot notation if a colon-separated import path was not provided + try: + val = getattr(ep, 'value', '') + if not val: + raise ValueError(f"Entry point {ep.name} has no value attribute") + module_name, attr = val.split(':', 1) if ':' in val else val.rsplit('.', 1) + module = __import__(module_name, fromlist=[attr]) + cls = getattr(module, attr) + except Exception as fallback_err: + # If fallback also fails, log both errors and re-raise + rank_zero_warn( + f"Failed to load entry point {ep.name} with ep.load(): {ep_err}. 
" + f"Fallback import also failed: {fallback_err}" + ) + raise + if hasattr(ep, "name") and isinstance(ep.name, str): + STRATEGY_ADAPTERS[ep.name.lower()] = cls + rank_zero_info(f"Discovered strategy adapter entrypoint '{ep.name}' -> {cls}") + except Exception as err: + rank_zero_warn(f"Failed to load strategy adapter entry point {ep}: {err}") + +# Note: _discover_strategy_adapters() is called lazily during strategy setup + + @dataclass class FTSState: """Dataclass to encapsulate the :class:`~finetuning_scheduler.fts.FinetuningScheduler` internal state.""" @@ -533,6 +582,7 @@ class ScheduleParsingMixin(ABC): ft_schedule: Optional[Union[str, dict]] reinit_optim_cfg: Optional[Dict] reinit_lr_cfg: Optional[Dict] + strategy_adapter: StrategyAdapter # added to support adapter-based parameter naming def _validate_ft_sched(self) -> Tuple[int, int]: """Ensure the explicitly specified fine-tuning schedule has a valid configuration. @@ -546,7 +596,8 @@ def _validate_ft_sched(self) -> Tuple[int, int]: max_phase = 0 self._validate_schedule_keys() self._validate_reinit_cfg() - named_params = dict(self.pl_module.named_parameters()).keys() + # Use strategy adapter to get named params (allows TL-style or other custom naming) + named_params = self.strategy_adapter.get_named_params_for_schedule_validation().keys() model_shared_params = find_shared_parameters(self.pl_module) msp_ref = tuple((model_shared_params, set(itertools.chain(*model_shared_params)))) for depth in self.ft_schedule.keys(): # type: ignore[union-attr] @@ -1172,8 +1223,8 @@ def _import_reinit_class( return reinit_class @staticmethod - def _import_strategy_adapter(strategy_key: str, adapter_map: Dict[str, str]) -> Type[StrategyAdapter]: - """Import the custom strategy adapter specified in the ``custom_strategy_adapter`` configuration. + def _resolve_strategy_adapter(strategy_key: str, adapter_map: Dict[str, str]) -> Type[StrategyAdapter]: + """Resolve the custom strategy adapter specified in the ``custom_strategy_adapters`` configuration. Args: qualname (Dict): The user-provided custom strategy adapter fully qualified class name. @@ -1190,13 +1241,29 @@ def _import_strategy_adapter(strategy_key: str, adapter_map: Dict[str, str]) -> if not qualname: raise MisconfigurationException( f"Current strategy name ({strategy_key}) does not map to a custom strategy adapter in the" - f" provided `custom_strategy_adapter` mapping ({adapter_map})." + f" provided `custom_strategy_adapters` mapping ({adapter_map})." ) - class_module, class_name = qualname.rsplit(".", 1) + # If a short entry point name was provided, check the discovered STRATEGY_ADAPTERS mapping + if qualname in STRATEGY_ADAPTERS: + return STRATEGY_ADAPTERS[qualname] + + # Accept either 'module.Class' or 'module:Class' form to support both direct import paths and + # entry-point-like strings. + if ":" in qualname: + class_module, class_name = qualname.split(":", 1) + else: + # Require at least one '.' for module.Class format + if "." not in qualname: + raise ValueError( + f"Invalid adapter name '{qualname}'. Must be either a registered plugin name " + f"(found in STRATEGY_ADAPTERS: {list(STRATEGY_ADAPTERS.keys())}) or a fully qualified " + f"class name in 'module.Class' or 'module:Class' format." 
+ ) + class_module, class_name = qualname.rsplit(".", 1) module = __import__(class_module, fromlist=[class_name]) custom_strategy_adapter_cls = getattr(module, class_name) issubclass(custom_strategy_adapter_cls, StrategyAdapter) - except (ImportError, AttributeError) as err: + except (ImportError, AttributeError, ValueError) as err: error_msg = ( "Could not import the specified custom strategy adapter class using the provided fully qualified class" f" name ({qualname}). Received the following error while importing: {err}. Please validate specified" @@ -1373,7 +1440,8 @@ def init_ft_sched(self) -> None: self.max_depth = len(self.ft_schedule) - 1 else: self.max_depth = min(self.max_depth, len(self.ft_schedule) - 1) - max_phase, max_epoch_wm = self._validate_ft_sched() # type: ignore[attr-defined] + # Delegate schedule validation to strategy adapter (allows custom validation logic) + max_phase, max_epoch_wm = self.strategy_adapter.validate_ft_sched() # if the final phase is not using EarlyStopping, apply the maximum phase-specified epoch to global max_epochs if self.ft_schedule[max_phase]["max_transition_epoch"] >= 0: assert self.trainer is not None @@ -1394,7 +1462,7 @@ def gen_implicit_schedule(self, sched_dir: Union[str, os.PathLike]) -> None: sched_dir: directory to which the generated schedule should be written. By default will be ``Trainer.log_dir``. """ - default_ft_schedule = ScheduleImplMixin.gen_ft_schedule(self.pl_module, sched_dir) + default_ft_schedule = self.strategy_adapter.gen_ft_schedule(sched_dir) assert default_ft_schedule is not None rank_zero_info(f"Generated default fine-tuning schedule '{default_ft_schedule}' for iterative fine-tuning") self.ft_schedule = self.load_yaml_schedule(default_ft_schedule) @@ -1429,6 +1497,11 @@ def save_schedule(schedule_name: str, layer_config: Dict, dump_loc: Union[str, o def gen_ft_schedule(module: Module, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: """Generate the default fine-tuning schedule using a naive, 2-parameters per-level heuristic. + .. deprecated:: 2.10.0 + Direct calls to this static method are deprecated. Use the + :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.gen_ft_schedule` instance method + instead, which allows strategy adapters to customize schedule generation. + Args: module (:class:`~torch.nn.Module`): The :class:`~torch.nn.Module` for which a fine-tuning schedule will be generated @@ -1438,11 +1511,33 @@ def gen_ft_schedule(module: Module, dump_loc: Union[str, os.PathLike]) -> Option :external+pl:class:`~lightning.pytorch.core.module.LightningModule` subclass in use with the suffix ``_ft_schedule.yaml``) """ + rank_zero_warn( + "Direct calls to ScheduleImplMixin.gen_ft_schedule() are deprecated since v2.10.0 and will be " + "removed in v2.12.0. Use strategy_adapter.gen_ft_schedule() instead to allow strategy-specific " + "customization." + ) + return ScheduleImplMixin._gen_ft_schedule_impl(module, dump_loc) + + @staticmethod + def _gen_ft_schedule_impl(module: Module, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + """Internal implementation of default fine-tuning schedule generation. + + This method contains the actual schedule generation logic shared between the deprecated static method + and the strategy adapter instance method. 
+ + Args: + module (:class:`~torch.nn.Module`): The :class:`~torch.nn.Module` for which a fine-tuning schedule will be + generated + dump_loc: The directory to which the generated schedule (.yaml) should be written + + Returns: + os.PathLike: The path to the generated schedule + """ # Note: This initial default fine-tuning schedule generation approach is intentionally simple/naive but is - # effective for a suprising fraction of models. Future versions of this callback may use module introspection to - # generate default schedules that better accommodate more complex structures and specific architectures if the - # callback proves sufficiently useful. - log.info(f"Proceeding with dumping default fine-tuning schedule for {module.__class__.__name__}") + # effective for a surprising fraction of models. Future versions of this callback may use module + # introspection to generate default schedules that better accommodate more complex structures and + # specific architectures if the callback proves sufficiently useful. + rank_zero_info(f"Proceeding with dumping default fine-tuning schedule for {module.__class__.__name__}") param_lists: List = [] cur_group: List = [] model_params = list(module.named_parameters())[::-1] diff --git a/src/finetuning_scheduler/strategy_adapters/__init__.py b/src/finetuning_scheduler/strategy_adapters/__init__.py index fba81f2..56f031d 100644 --- a/src/finetuning_scheduler/strategy_adapters/__init__.py +++ b/src/finetuning_scheduler/strategy_adapters/__init__.py @@ -12,6 +12,49 @@ r""" Fine-Tuning Scheduler Strategy Adapters ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Strategy adapters extend Fine-Tuning Scheduler support for complex or custom training strategies. +The built-in adapters (:class:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter`, +:class:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter`) handle PyTorch's +advanced distributed training strategies. + +Plugin Support +************** + +.. warning:: + This is an :ref:`experimental ` feature which is + still in development. + +Fine-Tuning Scheduler supports custom strategy adapters via Python entry points. Third-party packages +can register custom strategy adapters that will be automatically discovered at runtime. + +To register a custom strategy adapter, add an entry point in your package's ``pyproject.toml``: + +.. code-block:: toml + + [project.entry-points."finetuning_scheduler.strategy_adapters"] + my_adapter = "my_package.adapters:MyStrategyAdapter" + +The entry point name (``my_adapter`` in this example) will be used to reference the adapter, +automatically lowercased. Once registered, the adapter can be used by mapping Lightning strategy +flags to the adapter via the ``custom_strategy_adapters`` parameter. You can use the entry point +name, a fully qualified class path with colon separator (``module:Class``), or dot separator +(``module.Class``): + +.. code-block:: python + + from finetuning_scheduler import FinetuningScheduler + + # Map strategy flags to adapters using entry point name + fts = FinetuningScheduler( + custom_strategy_adapters={ + "single_device": "my_adapter", # Entry point name + "ddp": "my_package.adapters:MyStrategyAdapter", # Colon-separated + "fsdp": "my_package.adapters.MyStrategyAdapter", # Dot-separated + } + ) + +See :ref:`strategy_adapter_entry_points` for complete documentation and examples. 
""" from finetuning_scheduler.strategy_adapters.base import StrategyAdapter from finetuning_scheduler.strategy_adapters.fsdp import FSDPStrategyAdapter diff --git a/src/finetuning_scheduler/strategy_adapters/base.py b/src/finetuning_scheduler/strategy_adapters/base.py index 8496183..b8ddd96 100644 --- a/src/finetuning_scheduler/strategy_adapters/base.py +++ b/src/finetuning_scheduler/strategy_adapters/base.py @@ -18,7 +18,9 @@ """ from functools import partialmethod from pprint import pformat as pfmt -from typing import Callable, Iterable, List, Optional, Tuple, Dict, Union +from typing import Callable, Iterable, List, Optional, Tuple, Dict, Union, Any +import logging +import os import torch from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -29,6 +31,8 @@ from lightning.pytorch.strategies.strategy import Strategy from lightning.pytorch.utilities.rank_zero import rank_zero_debug +log = logging.getLogger(__name__) + class StrategyAdapter: r"""Base class for all strategy adapters. Implements the default @@ -37,11 +41,6 @@ class StrategyAdapter: :external+pl:class:`~lightning.pytorch.strategies.Strategy` via an associated :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter`. - .. warning:: - - :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` is in BETA and subject to change. The interface - can bring breaking changes and new features with the next release of FTS. - .. tip:: If you want to extend FTS to use a custom, currently unsupported strategy or override current FTS behavior in @@ -165,6 +164,20 @@ def fts_optim_transform(self, orig_pl: List, inspect_only: bool = False) -> List """ return orig_pl + def before_restore_model(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + """Adapter hook executed before model restore. + + Strategy adapters can override this to modify or translate the checkpoint contents (e.g. for state-dict + translations) before the model's load path is executed. + + Args: + checkpoint (Dict[str, Any]): The full checkpoint dict loaded by the Trainer. + + Returns: + Dict[str, Any]: The checkpoint dictionary to be used for restore. + """ + return checkpoint + def logical_param_translation(self, param_names: List) -> List: """Effectively the reverse transformation of :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.fts_optim_transform`. Can be overridden by a @@ -380,3 +393,90 @@ def _get_target_bn_modules(self, schedule_phase: int) -> List: isinstance(m, torch.nn.modules.batchnorm._BatchNorm)] fts_optim_inspect = partialmethod(fts_optim_transform, inspect_only=True) + + def get_named_params_for_schedule_validation(self) -> Dict[str, torch.nn.Parameter]: + """Get named parameters for schedule validation. + + This method can be overridden by :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` + subclasses to customize parameter iteration for schedule validation (e.g., returning TL-style + parameter names instead of canonical names). + + .. note:: + Strategy adapters can override validation behavior at two levels of abstraction: + + 1. **Parameter naming only** (simpler): Override this method to provide custom parameter names + while using the default validation logic from + :meth:`~finetuning_scheduler.fts_supporters.ScheduleParsingMixin._validate_ft_sched`. + + 2. **Full validation logic** (more control): Override + :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.validate_ft_sched` to + completely customize the validation process. + + Choose the approach that best suits your use case. 
Most adapters only need to override this + method to provide custom parameter names. + + Returns: + Dict[str, torch.nn.Parameter]: A dictionary mapping parameter names to parameter tensors. + By default, returns the standard ``named_parameters()`` dict. + """ + return dict(self.pl_module.named_parameters()) + + def validate_ft_sched(self) -> Tuple[int, int]: + """Validate the fine-tuning schedule configuration. + + This method can be overridden by :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` + subclasses to customize schedule validation for specific strategies (e.g., strategies that + require substantially different validation logic beyond just custom parameter naming). + + .. note:: + Strategy adapters can override validation behavior at two levels of abstraction: + + 1. **Parameter naming only** (simpler): Override + :meth:`~finetuning_scheduler.strategy_adapters.StrategyAdapter.get_named_params_for_schedule_validation` + to provide custom parameter names while using the default validation logic from + :meth:`~finetuning_scheduler.fts_supporters.ScheduleParsingMixin._validate_ft_sched`. + + 2. **Full validation logic** (more control): Override this method to completely customize + the validation process. + + Choose the approach that best suits your use case. Most adapters only need to override + :meth:`get_named_params_for_schedule_validation` to provide custom parameter names. + + Returns: + Tuple[int, int]: A tuple of ints specifying: + 1. The depth of the final scheduled phase + 2. The maximum epoch watermark explicitly specified in the schedule + """ + # Import here to avoid circular dependency + from finetuning_scheduler.fts_supporters import ScheduleParsingMixin + from typing import cast + + rank_zero_debug( + f"[base StrategyAdapter.validate_ft_sched] Validating schedule for " + f"{self.pl_module.__class__.__name__}" + ) + # Delegate to the mixin's implementation by default. + return ScheduleParsingMixin._validate_ft_sched(cast(ScheduleParsingMixin, self.fts_handle)) + + def gen_ft_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + """Generate the default fine-tuning schedule using a naive, 2-parameters per-level heuristic. + + This method can be overridden by :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` + subclasses to customize schedule generation for specific strategies (e.g., using strategy-specific + parameter naming conventions). 
+ + Args: + dump_loc: The directory to which the generated schedule (.yaml) should be written + + Returns: + os.PathLike: The path to the generated schedule, by default ``Trainer.log_dir`` and named after the + :external+pl:class:`~lightning.pytorch.core.module.LightningModule` subclass in use with the suffix + ``_ft_schedule.yaml``) + """ + # Import here to avoid circular dependency + from finetuning_scheduler.fts_supporters import ScheduleImplMixin + + rank_zero_debug( + f"[base StrategyAdapter.gen_ft_schedule] Generating schedule for {self.pl_module.__class__.__name__}" + ) + return ScheduleImplMixin._gen_ft_schedule_impl(self.pl_module, dump_loc) diff --git a/src/finetuning_scheduler/strategy_adapters/model_parallel.py b/src/finetuning_scheduler/strategy_adapters/model_parallel.py index 0b5efb2..a9598c3 100644 --- a/src/finetuning_scheduler/strategy_adapters/model_parallel.py +++ b/src/finetuning_scheduler/strategy_adapters/model_parallel.py @@ -80,10 +80,6 @@ class ModelParallelStrategyAdapter(StrategyAdapter): See the :ref:`model-parallel-fine-tuning-examples` tutorial for a concrete example and additional guidance. - .. warning:: - :class:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter` is in BETA and subject to change. - The interface can bring breaking changes and new features with the next release of PyTorch. - .. note:: ``fsdp_plan`` module name/pattern-based ``fully_shard`` directives are applied after any preceding Tensor Parallel or explicit ``fully_shard`` directives in ``LightningModule.configure_model``. FTS will only apply diff --git a/src/fts_examples/profiling/memprofiler.py b/src/fts_examples/profiling/memprofiler.py index 8ce1866..2e4e722 100644 --- a/src/fts_examples/profiling/memprofiler.py +++ b/src/fts_examples/profiling/memprofiler.py @@ -49,10 +49,6 @@ class MemProfiler: - `cuda memory snapshot and allocator history tracking `_ - host-level memory tracking - custom memory hooks (e.g. for activation checkpoint memory tracking via ``saved_tensors_hooks`` etc.) - - .. warning:: - :class:`~fts_examples.profiling.memprofiler.MemProfiler` is in BETA and subject to change. - The interface can bring breaking changes and new features with the next release of Finetuning Scheduler. """ def __init__(self, *args, **kwargs) -> None: """The MemProfiler is a powerful memory profiling utility that synthesizes numerous complementary profiling diff --git a/tests/helpers/expected_warns.py b/tests/helpers/expected_warns.py index 4d68190..d4666fa 100644 --- a/tests/helpers/expected_warns.py +++ b/tests/helpers/expected_warns.py @@ -22,7 +22,7 @@ "`max_epochs` was not", "The dirpath has changed from", "unless they are explicitly allowlisted", # required for oldest pytorch (2.5.0) with Lightning 2.5.6 - "Conversion of an array with ndim > 0", # still needed with python 3.9 and torch 2.4.0 + "Conversion of an array with ndim > 0", "Please use the new API settings to control TF32 behavior", # TODO: temporarily required with 20250811 nightly "treespec, LeafSpec", # TODO: required temporarily while lightning uses deprecated PT pytree API "torch.jit.script", # TODO: required temporarily with PT 2.10 nightly 20251124 due to upstream import diff --git a/tests/helpers/fake_adapter.py b/tests/helpers/fake_adapter.py new file mode 100644 index 0000000..aca24ce --- /dev/null +++ b/tests/helpers/fake_adapter.py @@ -0,0 +1,6 @@ +class FakeAdapter: + """A simple fake Strategy Adapter class for entry point discovery tests. 
+ + Exists solely so entry-points pointing into the package can be imported in unit tests. + """ + pass diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py index 4435cca..0938bac 100644 --- a/tests/test_finetuning_scheduler_callback.py +++ b/tests/test_finetuning_scheduler_callback.py @@ -327,7 +327,7 @@ def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None: if self.mock_strategy: trainer._accelerator_connector._strategy_flag = MOCK_STRATEGY_MAPPING[self.mock_strategy][0] self.allow_untested = MOCK_STRATEGY_MAPPING[self.mock_strategy][1] - self.custom_strategy_adapter = MOCK_STRATEGY_MAPPING[self.mock_strategy][2] + self.custom_strategy_adapters = MOCK_STRATEGY_MAPPING[self.mock_strategy][2] super().setup(trainer, pl_module, stage) if self.mock_strategy and self.allow_untested: raise SystemExit(0) @@ -1111,6 +1111,28 @@ def test_fts_gen_ft_schedule(tmpdir, model: "LightningModule", dist_mode: bool, assert test_schedule[1]["params"] == expected[1] assert test_schedule[next(reversed(list(test_schedule.keys())))]["params"] == expected[2] + +def test_fts_gen_ft_schedule_deprecation_warning(tmpdir): + """Validate that direct calls to ScheduleImplMixin.gen_ft_schedule() issue a deprecation warning.""" + from finetuning_scheduler.fts_supporters import ScheduleImplMixin + + model = FinetuningSchedulerBoringModel() + + # Test that the deprecation warning is issued + with pytest.warns( + UserWarning, + match=r"Direct calls to ScheduleImplMixin\.gen_ft_schedule\(\) are deprecated since v2\.10\.0", + ): + schedule_path = ScheduleImplMixin.gen_ft_schedule(model, tmpdir) + + # Verify the schedule was still generated correctly + assert schedule_path is not None + assert os.path.isfile(schedule_path) + with open(schedule_path) as f: + test_schedule = yaml.safe_load(f.read()) + assert isinstance(test_schedule, Dict) + assert len(test_schedule) > 0 + @pytest.mark.skipif(not _MLFLOW_AVAILABLE, reason="test requires MLflow") @pytest.mark.parametrize("use_fts_log_dir", [True, False], ids=["fts_log_dir", "no_fts_log_dir"]) def test_fts_log_dir(tmpdir, use_fts_log_dir): @@ -2028,6 +2050,43 @@ def test_fts_unallowed_key_error(): test_fts.restore_best_ckpt() +def test_fts_strategy_adapter_restore_exception(): + """Test that exceptions raised by strategy adapter's before_restore_model hook are properly caught and + warned.""" + test_fts = FinetuningScheduler() + test_fts.pl_module, test_fts.trainer = mock.MagicMock(), mock.MagicMock() + + # Mock the checkpoint connector and its loaded checkpoint + mock_checkpoint = {"state_dict": {}, "optimizer_states": [{}]} + test_fts.trainer._checkpoint_connector._loaded_checkpoint = mock_checkpoint + test_fts.trainer._checkpoint_connector.resume_start = mock.MagicMock() + test_fts.trainer._checkpoint_connector.restore_datamodule = mock.MagicMock() + test_fts.trainer._checkpoint_connector.restore_model = mock.MagicMock() + test_fts.trainer._checkpoint_connector.resume_end = mock.MagicMock() + test_fts.trainer.strategy.barrier = mock.MagicMock() + test_fts.trainer.checkpoint_callback = mock.MagicMock() + test_fts.trainer.checkpoint_callback.best_model_path = "/mock/path/checkpoint.ckpt" + test_fts.trainer.optimizers = [mock.MagicMock()] + test_fts._fts_state._fts_ckpt_metadata = {"best_ckpt_pgs": [[]]} + + # Mock the strategy adapter to raise an exception in before_restore_model + mock_adapter = mock.MagicMock() + test_error_message = "Mock adapter transformation error" + mock_adapter.before_restore_model = 
mock.MagicMock(side_effect=RuntimeError(test_error_message)) + mock_adapter.using_sharded_optimizer = False + test_fts.strategy_adapter = mock_adapter + + # Mock _restore_training_state to prevent further execution + test_fts._restore_training_state = mock.MagicMock() + + # Execute and verify warning is issued + with pytest.warns(UserWarning, match=f"Strategy adapter before_restore_model hook raised.*{test_error_message}"): + test_fts.restore_best_ckpt() + + # Verify that restore_model was still called despite the exception + test_fts.trainer._checkpoint_connector.restore_model.assert_called_once() + + @pytest.mark.parametrize( "explicit_mode, lam_mode, w_expected", [ diff --git a/tests/test_strategy_adapter_discovery.py b/tests/test_strategy_adapter_discovery.py new file mode 100644 index 0000000..5fab1ae --- /dev/null +++ b/tests/test_strategy_adapter_discovery.py @@ -0,0 +1,165 @@ +from importlib.metadata import EntryPoint +from unittest.mock import MagicMock +import pytest + +from finetuning_scheduler import fts_supporters +from finetuning_scheduler.fts import FinetuningScheduler +from lightning.pytorch.utilities.exceptions import MisconfigurationException + + +def test_discover_strategy_adapters(monkeypatch): + # Create fake entrypoints that point to a plugin adapter using two common formats + # Use a local fake adapter class for discovery tests to avoid importing external packages + ep_colon = EntryPoint( + name="fake_adapter_colon", + value="tests.helpers.fake_adapter:FakeAdapter", + group="finetuning_scheduler.strategy_adapters", + ) + ep_dot = EntryPoint( + name="fake_adapter_dot", + value="tests.helpers.fake_adapter.FakeAdapter", + group="finetuning_scheduler.strategy_adapters", + ) + + def fake_entry_points(group=None): + if group == "finetuning_scheduler.strategy_adapters": + return [ep_colon, ep_dot] + return [] + + # Override the entry_points function in the fts_supporters module. 
+ # Monkeypatch the importlib.metadata.entry_points function so our fake entrypoints are used by discovery + monkeypatch.setattr("importlib.metadata.entry_points", fake_entry_points) + # call discovery explicitly (no reload, since entry_points function is patched) + # call the function again - it will load entrypoints and register + fts_supporters._discover_strategy_adapters() + assert "fake_adapter_colon" in fts_supporters.STRATEGY_ADAPTERS + assert "fake_adapter_dot" in fts_supporters.STRATEGY_ADAPTERS + + +def test_resolve_strategy_adapter_by_qualname(): + adapter_map = {"single_device": "tests.helpers.fake_adapter:FakeAdapter"} + fts = FinetuningScheduler() + cls = fts._resolve_strategy_adapter("single_device", adapter_map) + assert cls.__name__ == "FakeAdapter" + + +def test_resolve_strategy_adapter_by_dot_form(): + adapter_map = {"single_device": "tests.helpers.fake_adapter.FakeAdapter"} + fts = FinetuningScheduler() + cls = fts._resolve_strategy_adapter("single_device", adapter_map) + assert cls.__name__ == "FakeAdapter" + + +def test_resolve_strategy_adapter_by_plugin_name(monkeypatch): + """Test importing strategy adapter using discovered plugin entry point name.""" + # register alias in STRATEGY_ADAPTERS to simulate a discovered plugin (monkeypatched so it's reverted) + monkeypatch.setitem( + fts_supporters.STRATEGY_ADAPTERS, "fakeplugin", fts_supporters.STRATEGY_ADAPTERS.get("fsdp") + ) + adapter_map = {"single_device": "fakeplugin"} + fts = FinetuningScheduler() + cls = fts._resolve_strategy_adapter("single_device", adapter_map) + assert cls in fts_supporters.STRATEGY_ADAPTERS.values() + + +def test_discover_strategy_adapters_ep_load_failure(monkeypatch): + """Test that ep.load() failure triggers fallback import and succeeds.""" + # Create a mock entry point that will fail on ep.load() but succeed with fallback + ep = MagicMock() + ep.name = "fallback_adapter" + ep.value = "tests.helpers.fake_adapter:FakeAdapter" + ep.load.side_effect = ImportError("Simulated ep.load() failure") + + def fake_entry_points(group=None): + if group == "finetuning_scheduler.strategy_adapters": + return [ep] + return [] + + monkeypatch.setattr("importlib.metadata.entry_points", fake_entry_points) + + # Should succeed via fallback import despite ep.load() failing + fts_supporters._discover_strategy_adapters() + assert "fallback_adapter" in fts_supporters.STRATEGY_ADAPTERS + + +def test_discover_strategy_adapters_both_imports_fail(monkeypatch): + """Test that both ep.load() and fallback import failures are handled gracefully.""" + # Create a mock entry point with invalid value that will fail both import methods + ep = MagicMock() + ep.name = "broken_adapter" + ep.value = "nonexistent.module:NonexistentClass" + ep.load.side_effect = ImportError("Simulated ep.load() failure") + + def fake_entry_points(group=None): + if group == "finetuning_scheduler.strategy_adapters": + return [ep] + return [] + + monkeypatch.setattr("importlib.metadata.entry_points", fake_entry_points) + + # Should not raise, just log warnings and continue + with pytest.warns(UserWarning, match="Failed to load strategy adapter entry point"): + fts_supporters._discover_strategy_adapters() + + # Adapter should not be registered + assert "broken_adapter" not in fts_supporters.STRATEGY_ADAPTERS + + +def test_discover_strategy_adapters_no_value_attribute(monkeypatch): + """Test handling of entry point with missing value attribute.""" + # Create a mock entry point without a value attribute + ep = MagicMock() + ep.name = "no_value_adapter" + 
ep.load.side_effect = AttributeError("No value attribute") + # Simulate missing value attribute + type(ep).value = property(lambda self: None) + + def fake_entry_points(group=None): + if group == "finetuning_scheduler.strategy_adapters": + return [ep] + return [] + + monkeypatch.setattr("importlib.metadata.entry_points", fake_entry_points) + + # Should handle gracefully and log warning + with pytest.warns(UserWarning, match="Failed to load strategy adapter entry point"): + fts_supporters._discover_strategy_adapters() + + assert "no_value_adapter" not in fts_supporters.STRATEGY_ADAPTERS + + +def test_resolve_strategy_adapter_invalid_format(): + """Test that invalid adapter name format raises MisconfigurationException.""" + # Test with a name that has no dots or colons (invalid format) + adapter_map = {"single_device": "invalidname"} + fts = FinetuningScheduler() + + with pytest.raises( + MisconfigurationException, + match=r"Invalid adapter name 'invalidname'.*Must be either a registered plugin name.*or a fully qualified", + ): + fts._resolve_strategy_adapter("single_device", adapter_map) + + +def test_resolve_strategy_adapter_import_error(): + """Test that import errors are properly handled and re-raised as MisconfigurationException.""" + adapter_map = {"single_device": "nonexistent.module:NonexistentClass"} + fts = FinetuningScheduler() + + with pytest.raises( + MisconfigurationException, + match=r"Could not import the specified custom strategy adapter class.*nonexistent\.module:NonexistentClass", + ): + fts._resolve_strategy_adapter("single_device", adapter_map) + + +def test_resolve_strategy_adapter_missing_strategy_key(): + """Test that missing strategy key in adapter_map raises MisconfigurationException.""" + adapter_map = {"ddp": "some_adapter"} # Missing "single_device" key + fts = FinetuningScheduler() + + with pytest.raises( + MisconfigurationException, + match=r"Current strategy name \(single_device\) does not map to a custom strategy adapter", + ): + fts._resolve_strategy_adapter("single_device", adapter_map)
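
Illustrative reference (not part of the patch): a minimal sketch of the three-line requirements/ci/torch-pre.txt configuration consumed by read_torch_pre_config() in scripts/infra_utils.sh and by the install-ci-dependencies action. The values shown are placeholders drawn from examples used elsewhere in this diff; comment lines are skipped by the readers, and the channel on line 3 must be "test" or "nightly".

# Hypothetical torch-pre.txt contents
# Line 1: torch version, Line 2: CUDA target (local/CUDA builds), Line 3: channel type
2.10.0.dev20251124
cu128
nightly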