diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml
new file mode 100644
index 0000000..ddcb62e
--- /dev/null
+++ b/.github/workflows/copilot-setup-steps.yml
@@ -0,0 +1,45 @@
+name: "Copilot Setup Steps"
+
+on:
+  workflow_dispatch:
+  push:
+    paths:
+      - .github/workflows/copilot-setup-steps.yml
+  pull_request:
+    paths:
+      - .github/workflows/copilot-setup-steps.yml
+
+jobs:
+  copilot-setup-steps:
+    runs-on: ubuntu-latest
+    environment: copilot
+    permissions:
+      contents: read
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install CI dependencies
+        uses: ./.github/actions/install-ci-dependencies
+        with:
+          python_version: '3.12'
+          show_pip_list: 'true'
+
+      - name: Install pre-commit
+        shell: bash
+        run: |
+          pre-commit install
+
+      - name: Run type checking
+        shell: bash
+        run: |
+          pyright -p pyproject.toml
+
+      - name: Run pre-commit hooks
+        shell: bash
+        run: |
+          pre-commit run --all-files
+
+      - name: Run Copilot setup steps
+        run: echo "Copilot setup steps completed successfully."
diff --git a/pyproject.toml b/pyproject.toml
index 67ecf13..d8282e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -179,33 +179,9 @@ known-first-party = ["docs","fts_examples","finetuning_scheduler","tests"]
 force-sort-within-sections = false
 order-by-type = false
 
-[tool.mypy]
-files = ["src/finetuning_scheduler"]
-disallow_untyped_defs = "True"
-ignore_missing_imports = "True"
-show_error_codes = "True"
-warn_redundant_casts = "True"
-warn_unused_configs = "True"
-warn_unused_ignores = "False"
-allow_redefinition = "True"
-# disable this rule as the PL Trainer attributes are defined in the connectors, not in its __init__
-disable_error_code = "attr-defined"
-# style choices
-warn_no_return = "False"
-exclude = ['tests/.*']
-
-# Ignore mypy errors for these files
-# TODO: the goal is for this to be empty
-#[[tool.mypy.overrides]]
-# the list can be generated with:
-# mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
-# module = []
-# ignore_errors = "True"
-
 [tool.pyright]
-# Using "basic" mode for initial migration from mypy
-# TODO: Progressively tighten to "standard" as type issues are resolved
-typeCheckingMode = "basic"
+# Using "standard" mode for comprehensive type checking
+typeCheckingMode = "standard"
 include = ["src/finetuning_scheduler"]
 exclude = [
     "tests",
@@ -214,22 +190,21 @@ exclude = [
     "dist",
     ".git",
 ]
-# Match mypy's ignore_missing_imports behavior
+# Ignore missing imports from third-party libraries without type stubs
 reportMissingImports = "none"
-# Match mypy's disable_error_code = "attr-defined" for PL Trainer attributes
+# Disable attribute access checks - PL Trainer attributes are defined in connectors, not __init__
 reportAttributeAccessIssue = "none"
-# Match mypy's allow_redefinition behavior
+# Allow constant redefinition for dynamic configuration patterns
 reportConstantRedefinition = "none"
 # Disable private import usage warnings - required for Lightning internals
 reportPrivateImportUsage = "none"
-# Disable general type issues for generator context managers until properly annotated
-reportGeneralTypeIssues = "warning"
+# Enable comprehensive type checking for standard mode
+reportGeneralTypeIssues = "error"
+reportAssignmentType = "error"
+reportCallIssue = "error"
+reportIndexIssue = "error"
 # Disable invalid type form warnings for dynamic type expressions
 reportInvalidTypeForm = "none"
-# Temporarily disable these until code is updated with proper type annotations
-reportAssignmentType = "none"
-reportCallIssue = "none"
-reportIndexIssue = "none"
 
 [tool.coverage.report]
 exclude_lines = [
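
Reviewer note (not part of the diff): the three rules promoted to "error" above cover distinct diagnostic families. A minimal sketch (hypothetical code, not from this repo) of what each now flags under typeCheckingMode = "standard":

    def takes_int(x: int) -> int:
        return x

    def _examples() -> None:  # never called; exists only to illustrate the diagnostics
        count: int = "3"   # reportAssignmentType: "str" is not assignable to "int"
        takes_int(1, 2)    # reportCallIssue: expected 1 positional argument
        takes_int(3)[0]    # reportIndexIssue: "int" is not subscriptable
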
diff --git a/requirements/ci/requirements-oldest.txt b/requirements/ci/requirements-oldest.txt
index 94a035f..20e13f0 100644
--- a/requirements/ci/requirements-oldest.txt
+++ b/requirements/ci/requirements-oldest.txt
@@ -125,7 +125,7 @@ evaluate==0.3.0
     # via finetuning-scheduler (pyproject.toml)
 exceptiongroup==1.3.1 ; python_full_version < '3.11'
     # via anyio
-fastapi==0.123.10
+fastapi==0.124.0
     # via mlflow-skinny
 filelock==3.20.0
     # via
@@ -474,7 +474,7 @@ propcache==0.4.1
     # via
     #   aiohttp
     #   yarl
-protobuf==6.33.1
+protobuf==6.33.2
     # via
     #   databricks-sdk
     #   mlflow-skinny
diff --git a/requirements/ci/requirements.txt b/requirements/ci/requirements.txt
index ae44b8a..6a0e957 100644
--- a/requirements/ci/requirements.txt
+++ b/requirements/ci/requirements.txt
@@ -1,5 +1,5 @@
 # This file was autogenerated by uv via the following command:
-# uv pip compile /home/speediedan/repos/finetuning-scheduler/pyproject.toml --extra all --group dev --group test --output-file /home/speediedan/repos/finetuning-scheduler/requirements/ci/requirements.txt --no-strip-extras --resolution highest --universal --python-version 3.10 --prerelease=if-necessary-or-explicit --override /tmp/tmp.bOqxnBiEpd --index-strategy unsafe-best-match --no-emit-package torch
+# uv pip compile /home/speediedan/repos/finetuning-scheduler/pyproject.toml --extra all --group dev --group test --output-file /home/speediedan/repos/finetuning-scheduler/requirements/ci/requirements.txt --no-strip-extras --resolution highest --universal --python-version 3.10 --prerelease=if-necessary-or-explicit --override /tmp/tmp.CNVCWjH0nH --index-strategy unsafe-best-match --no-emit-package torch
 aiohappyeyeballs==2.6.1
     # via aiohttp
 aiohttp==3.13.2
@@ -137,7 +137,7 @@ exceptiongroup==1.3.1 ; python_full_version < '3.11'
     #   pytest
 executing==2.2.1
     # via stack-data
-fastapi==0.123.10
+fastapi==0.124.0
     # via mlflow-skinny
 fastjsonschema==2.21.2
     # via nbformat
@@ -381,8 +381,6 @@ more-itertools==10.8.0 ; platform_machine != 'ppc64le' and platform_machine != '
     # via
     #   jaraco-classes
     #   jaraco-functools
-mpmath==1.3.0
-    # via sympy
 multidict==6.7.0
     # via
     #   aiohttp
@@ -411,10 +409,6 @@ nbval==0.11.0
     # via finetuning-scheduler (pyproject.toml)
 nest-asyncio==1.6.0
     # via ipykernel
-networkx==3.4.2 ; python_full_version < '3.11'
-    # via torch
-networkx==3.6 ; python_full_version >= '3.11'
-    # via torch
 nh3==0.3.2
     # via readme-renderer
 nodeenv==1.9.1
@@ -534,7 +528,7 @@ propcache==0.4.1
     # via
     #   aiohttp
     #   yarl
-protobuf==6.33.1
+protobuf==6.33.2
     # via
     #   databricks-sdk
     #   mlflow-skinny
@@ -725,8 +719,6 @@ stack-data==0.6.3
     # via ipython
 starlette==0.50.0
     # via fastapi
-sympy==1.14.0
-    # via torch
 tabulate==0.9.0
     # via finetuning-scheduler (pyproject.toml)
 tensorboardx==2.6.4
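
Reviewer note (not part of the diff): these lockfile removals are exactly what the new pruning pass (requirements/utils/prune_torch_deps.py, below) derives from the "# via" comments: sympy and networkx are depended on only by torch, and mpmath's sole dependent is the now-pruned sympy, so it falls on the second pass. A standalone sketch of the fixed-point rule:

    # Dependents as recorded in the pre-prune lockfile's "# via" comments
    dependents: dict[str, set[str]] = {
        "sympy": {"torch"},
        "networkx": {"torch"},
        "mpmath": {"sympy"},                   # transitive: torch -> sympy -> mpmath
        "tabulate": {"finetuning-scheduler"},  # non-torch dependent -> kept
    }

    pruned: set[str] = set()
    while True:
        # prune anything whose dependents are all torch or already-pruned packages
        newly = {pkg for pkg, via in dependents.items()
                 if pkg not in pruned and via <= {"torch"} | pruned}
        if not newly:
            break
        pruned |= newly

    print(sorted(pruned))  # ['mpmath', 'networkx', 'sympy']
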
diff --git a/requirements/utils/lock_ci_requirements.sh b/requirements/utils/lock_ci_requirements.sh
index 7243a63..e6ad2fa 100755
--- a/requirements/utils/lock_ci_requirements.sh
+++ b/requirements/utils/lock_ci_requirements.sh
@@ -14,6 +14,7 @@
 #   - Lock file is generated with torch pinned to the nightly version
 #   - Uses PyTorch nightly index for resolution
 #   - Docker image and CI both use the same nightly version
+#   - Post-processing prunes torch-only dependencies (see prune_torch_only_deps)
 # - Without torch-nightly.txt:
 #   - Uses stable torch from PyPI
 #   - CI uses --torch-backend=cpu for CPU variant
@@ -67,6 +68,15 @@ get_torch_nightly_version() {
 # Check if torch nightly is configured
 TORCH_NIGHTLY_VERSION=$(get_torch_nightly_version)
 
+# Prune packages that are ONLY dependencies of torch from the lockfile.
+# This reduces the dependency confusion attack surface when using unsafe-best-match
+# by removing any packages that could potentially be resolved from the nightly index only.
+# See prune_torch_deps.py for detailed documentation and implementation.
+prune_torch_only_deps() {
+    local lockfile=$1
+    python "${SCRIPT_DIR}/prune_torch_deps.py" "${lockfile}"
+}
+
 # Generate/update torch_override.txt if nightly is configured, remove if not
 generate_torch_override() {
     if [[ -n "${TORCH_NIGHTLY_VERSION}" ]]; then
@@ -125,16 +135,25 @@ generate_lockfile() {
     # 1. Create a temporary override file to pin torch to the nightly version for dependency resolution
     # 2. Use --prerelease=if-necessary-or-explicit to only allow prereleases for explicitly specified packages (torch)
     #    or where all versions of the package are pre-release
-    # 3. Use --extra-index-url with nightly CPU index for torch resolution
+    # 3. Use --index with nightly CPU index for torch resolution
     # 4. Use --index-strategy=unsafe-best-match for lockfile GENERATION only
     #    This is required because with first-index (default), uv would either:
     #    - Find torch on PyPI first (no nightly version), or
     #    - Find scipy/etc on nightly index first (missing versions)
-    #    Security impact is minimal because:
-    #    - Lockfile generation runs on maintainer machines, not user machines
-    #    - Generated lockfile pins exact package versions from PyPI
-    #    - User INSTALLATION uses secure two-step approach (no unsafe-best-match)
+    #
+    #    Security rationale for using unsafe-best-match during lockfile generation:
+    #    a) User INSTALLATION uses a secure two-step approach, only ever installing torch nightly
+    #       from the explicitly specified nightly index (no unsafe-best-match at install time)
+    #    b) The marginal dependency confusion attack surface is limited to the closely monitored
+    #       PyTorch nightly index, which is maintained by the PyTorch team. Post-processing prunes any
+    #       packages that are ONLY dependencies of torch, eliminating potential attack vectors from
+    #       torch-exclusive dependencies that might only exist on the nightly index. If a package is
+    #       shared with other dependencies, it is already being resolved from PyPI and subject to normal security scanning.
+    #    c) Lockfile generation runs on maintainer machines, not user machines
+    #    d) Generated lockfile pins exact package versions from PyPI
+    #
     # 5. Use --no-emit-package=torch to exclude torch from output (installed separately with backend)
+    # 6. Post-process to prune torch-only dependencies (see prune_torch_only_deps)
     if [[ "${use_nightly}" == "true" && -n "${TORCH_NIGHTLY_VERSION}" ]]; then
         local torch_override_file=$(mktemp)
         echo "torch==${TORCH_NIGHTLY_VERSION}" > "${torch_override_file}"
@@ -142,7 +161,7 @@ generate_lockfile() {
         compile_cmd+=(
             --prerelease=if-necessary-or-explicit
             --override "${torch_override_file}"
-            --extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
+            --index "https://download.pytorch.org/whl/nightly/cpu"
             --index-strategy unsafe-best-match  # for lockfile generation only, see comment above
             --no-emit-package torch
         )
@@ -151,7 +170,11 @@ generate_lockfile() {
         "${compile_cmd[@]}"
 
         rm -f "${torch_override_file}"
-        echo "✓ Generated ${output_file} (torch ${TORCH_NIGHTLY_VERSION} excluded, install separately)"
+
+        # Prune torch-only dependencies to minimize dependency confusion attack surface
+        prune_torch_only_deps "${output_file}"
+
+        echo "✓ Generated ${output_file} (torch ${TORCH_NIGHTLY_VERSION} excluded, torch-only deps pruned)"
     else
         "${compile_cmd[@]}"
         echo "✓ Generated ${output_file}"
diff --git a/requirements/utils/prune_torch_deps.py b/requirements/utils/prune_torch_deps.py
new file mode 100644
index 0000000..baaa023
--- /dev/null
+++ b/requirements/utils/prune_torch_deps.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python3
+"""Prune packages that are ONLY dependencies of torch from a lockfile.
+
+This reduces the dependency confusion attack surface when using unsafe-best-match
+by removing any packages that could potentially be resolved from the nightly index only.
+
+How it works:
+1. Parse the lockfile to find all packages and their dependents (the "# via" comments)
+2. Identify packages where torch is the ONLY dependent (e.g., "# via torch")
+3. Iteratively remove those packages and their exclusive transitive dependencies
+4. Packages that are shared with other deps (e.g., "# via torch, transformers") are kept
+
+This ensures that even if a malicious package were introduced to the nightly index,
+it would only be resolved if it's also a dependency of other packages (which would
+mean it exists on PyPI and would be caught by normal security scanning).
+
+Usage:
+    python prune_torch_deps.py <lockfile>
+
+Example:
+    python prune_torch_deps.py requirements/ci/requirements.txt
+"""
+
+from __future__ import annotations
+
+import re
+import sys
+from pathlib import Path
+
+
+def parse_lockfile(content: str) -> dict[str, dict]:
+    """Parse a uv lockfile and extract package blocks with their dependencies.
+
+    Args:
+        content: The lockfile content as a string.
+
+    Returns:
+        A dict mapping package names to their block info:
+        {
+            "package_name": {
+                "lines": ["line1", "line2", ...],     # All lines in this block
+                "dependents": ["dep1", "dep2", ...],  # Packages that depend on this
+            }
+        }
+    """
+    packages: dict[str, dict] = {}
+    lines = content.splitlines()
+
+    current_pkg: str | None = None
+    current_lines: list[str] = []
+    current_dependents: list[str] = []
+    in_via_section = False
+
+    # Regex to capture the package name at the start of an unindented line,
+    # e.g. "numpy" from "numpy==1.26.4"
+    pkg_pattern = re.compile(r"^([a-zA-Z][a-zA-Z0-9_-]*)")
+
+    for line in lines:
+        # Check if this is a new package line (starts with letter, not indented)
+        if line and line[0].isalpha():
+            # Save previous package if exists
+            if current_pkg is not None:
+                if current_pkg not in packages:
+                    packages[current_pkg] = {"lines": [], "dependents": set()}
+                packages[current_pkg]["lines"].extend(current_lines)
+                packages[current_pkg]["dependents"].update(current_dependents)
+
+            # Extract package name (before == or space)
+            match = pkg_pattern.match(line)
+            if match:
+                current_pkg = match.group(1).lower()  # Normalize to lowercase
+            else:
+                current_pkg = None
+
+            current_lines = [line]
+            current_dependents = []
+            in_via_section = False
+
+        elif current_pkg is not None:
+            current_lines.append(line)
+
+            # Parse "# via" comments to find dependents
+            # Single-line format: "    # via package_name"
+            if line.strip().startswith("# via ") and not line.strip() == "# via":
+                dep = line.strip()[6:].strip()  # Remove "# via "
+                # Handle potential trailing content (markers, etc.)
+                dep = dep.split()[0] if dep else ""
+                # Normalize: replace hyphens with underscores for comparison
+                if dep and not dep.startswith("("):  # Skip "(pyproject.toml)" style
+                    current_dependents.append(dep.lower().replace("-", "_"))
+                in_via_section = False
+
+            # Multi-line format start: "    # via"
+            elif line.strip() == "# via":
+                in_via_section = True
+
+            # Multi-line format entries: "    #   package_name"
+            elif in_via_section and line.strip().startswith("# "):
+                dep = line.strip()[4:].strip()  # Remove the "#" and the padding spaces
+                dep = dep.split()[0] if dep else ""
+                if dep and not dep.startswith("("):
+                    current_dependents.append(dep.lower().replace("-", "_"))
+
+            # End of via section (any non-comment indented line or empty)
+            elif in_via_section and (not line.strip().startswith("#") or line.strip() == ""):
+                in_via_section = False
+
+    # Don't forget the last package
+    if current_pkg is not None:
+        if current_pkg not in packages:
+            packages[current_pkg] = {"lines": [], "dependents": set()}
+        packages[current_pkg]["lines"].extend(current_lines)
+        packages[current_pkg]["dependents"].update(current_dependents)
+
+    # Convert dependent sets to lists
+    for pkg_info in packages.values():
+        pkg_info["dependents"] = list(pkg_info["dependents"])
+
+    return packages
+
+
+def find_torch_only_packages(packages: dict[str, dict]) -> set[str]:
+    """Find packages that are ONLY dependencies of torch (and its exclusive deps).
+
+    Uses iterative approach to handle transitive dependencies:
+    - First pass: find packages where only "torch" is in dependents
+    - Subsequent passes: find packages where only already-pruned packages are dependents
+
+    Args:
+        packages: Dict from parse_lockfile()
+
+    Returns:
+        Set of package names to prune
+    """
+    pruned: set[str] = set()
+    max_iterations = 10  # Safety limit
+
+    for _ in range(max_iterations):
+        newly_pruned = set()
+        # Dependent names were underscore-normalized in parse_lockfile(), while
+        # package names keep their hyphens; normalize once per pass so the
+        # transitive check below also matches hyphenated package names.
+        pruned_normalized = {pkg.replace("-", "_") for pkg in pruned}
+
+        for pkg_name, pkg_info in packages.items():
+            if pkg_name in pruned:
+                continue
+
+            dependents = pkg_info["dependents"]
+            if not dependents:
+                continue
+
+            # Check if all dependents are either "torch" or already pruned
+            # For first-level deps: package is only required by "torch"
+            # For transitive deps: package is only required by already-pruned packages
+            all_dependents_prunable = all(
+                dep == "torch" or dep in pruned_normalized for dep in dependents
+            )
+
+            if all_dependents_prunable:
+                newly_pruned.add(pkg_name)
+
+        if not newly_pruned:
+            break
+
+        pruned.update(newly_pruned)
+
+    return pruned
+
+
+def prune_packages(content: str, packages_to_prune: set[str]) -> str:
+    """Remove specified packages from the lockfile content.
+
+    Args:
+        content: Original lockfile content
+        packages_to_prune: Set of package names to remove
+
+    Returns:
+        Updated lockfile content with packages removed
+    """
+    lines = content.splitlines()
+    result_lines: list[str] = []
+    skip_block = False
+
+    # Normalize package names for comparison
+    normalized_prune = {pkg.lower().replace("-", "_") for pkg in packages_to_prune}
+
+    pkg_pattern = re.compile(r"^([a-zA-Z][a-zA-Z0-9_-]*)")
+
+    for line in lines:
+        # Check if this is a new package line
+        if line and line[0].isalpha():
+            match = pkg_pattern.match(line)
+            if match:
+                pkg_name = match.group(1).lower().replace("-", "_")
+                skip_block = pkg_name in normalized_prune
+            else:
+                skip_block = False
+
+        if not skip_block:
+            result_lines.append(line)
+
+    return "\n".join(result_lines)
+
+
+def prune_torch_only_deps(lockfile_path: str) -> list[str]:
+    """Main function to prune torch-only dependencies from a lockfile.
+
+    Args:
+        lockfile_path: Path to the lockfile to process
+
+    Returns:
+        List of pruned package names
+    """
+    path = Path(lockfile_path)
+    content = path.read_text()
+
+    # Parse the lockfile
+    packages = parse_lockfile(content)
+
+    # Find packages to prune
+    to_prune = find_torch_only_packages(packages)
+
+    if not to_prune:
+        return []
+
+    # Prune and write back
+    new_content = prune_packages(content, to_prune)
+    path.write_text(new_content)
+
+    return sorted(to_prune)
+
+
+def main() -> int:
+    """CLI entry point."""
+    if len(sys.argv) != 2:
+        print(f"Usage: {sys.argv[0]} <lockfile>", file=sys.stderr)
+        return 1
+
+    lockfile_path = sys.argv[1]
+
+    if not Path(lockfile_path).exists():
+        print(f"Error: File not found: {lockfile_path}", file=sys.stderr)
+        return 1
+
+    print(f"  Post-processing: pruning torch-only dependencies from {lockfile_path}...")
+
+    pruned = prune_torch_only_deps(lockfile_path)
+
+    if pruned:
+        print("  Pruned torch-only dependencies:")
+        for pkg in pruned:
+            print(f"    - {pkg}")
+        print("  Lockfile updated with torch-only dependencies removed")
+    else:
+        print("  No torch-only dependencies found to prune")
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
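
Reviewer note (not part of the diff): a quick way to sanity-check the parser and the iterative pruning without touching a real lockfile (hypothetical snippet; assumes requirements/utils is on sys.path):

    from prune_torch_deps import find_torch_only_packages, parse_lockfile

    sample = (
        "mpmath==1.3.0\n"
        "    # via sympy\n"
        "sympy==1.14.0\n"
        "    # via torch\n"
        "tabulate==0.9.0\n"
        "    # via finetuning-scheduler (pyproject.toml)\n"
    )
    pkgs = parse_lockfile(sample)
    # sympy goes on the first pass ("# via torch"); mpmath on the second, once
    # its only dependent (sympy) has itself been pruned; tabulate survives.
    print(sorted(find_torch_only_packages(pkgs)))  # ['mpmath', 'sympy']
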
""" + qualname = adapter_map.get(strategy_key, None) try: - qualname = adapter_map.get(strategy_key, None) if not qualname: raise MisconfigurationException( f"Current strategy name ({strategy_key}) does not map to a custom strategy adapter in the" @@ -1198,7 +1199,7 @@ def _import_strategy_adapter(strategy_key: str, adapter_map: Dict[str, str]) -> except (ImportError, AttributeError) as err: error_msg = ( "Could not import the specified custom strategy adapter class using the provided fully qualified class" - f" name ({qualname}). Recieved the following error while importing: {err}. Please validate specified" + f" name ({qualname}). Received the following error while importing: {err}. Please validate specified" " path." ) rank_zero_warn(error_msg) @@ -1229,7 +1230,7 @@ def _optimizer_sanity_chk(self, optimizer_init: Dict) -> None: except Exception as err: error_msg = ( "Could not configure the specified optimizer class using the `init_args` " - f"({optimizer_init['init_args']}). Recieved the following error while sanity checking schedule " + f"({optimizer_init['init_args']}). Received the following error while sanity checking schedule " f"phases: {err}. Please validate specified `init_args` before resubmitting." ) rank_zero_warn(error_msg) @@ -1277,11 +1278,11 @@ def _lr_scheduler_sanity_chk(self, lr_scheduler_init: Dict, is_implicit_mode: bo del test_lr_init["min_lr"] # our mock optimizer will not have any param groups try: assert callable(lrs_class) - testlr = lrs_class(optimizer=_MockOptimizer(), **test_lr_init) + testlr = lrs_class(_MockOptimizer(), **test_lr_init) # type: ignore[call-arg] except Exception as err: error_msg = ( "Could not configure the specified LR scheduler class using the `init_args` " - f"({lr_scheduler_init['init_args']}). Recieved the following error while sanity checking schedule " + f"({lr_scheduler_init['init_args']}). Received the following error while sanity checking schedule " f"phases: {err}. Please validate specified `init_args` before resubmitting." ) rank_zero_warn(error_msg) @@ -1529,7 +1530,7 @@ def _repartition_sharded_optim(optimizer: ParamGroupAddable) -> None: else optimizer.partition_parameters ) optimizer._clear_cache() - optimizer.optim.param_groups = partition_params()[optimizer.rank] + optimizer.optim.param_groups = partition_params()[optimizer.rank] # type: ignore[index] optimizer._sync_param_groups(optimizer.optim.param_groups, optimizer.param_groups) def _restore_latest_lr_state(self, curr_lr_state: Dict, prev_optimizer_lrs: List) -> None: diff --git a/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py b/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py index bc658f0..e13e61f 100644 --- a/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py +++ b/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py @@ -9,13 +9,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py b/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py
index bc658f0..e13e61f 100644
--- a/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py
+++ b/src/finetuning_scheduler/strategy_adapters/_wrap_utils.py
@@ -9,13 +9,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Set, Iterator, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Set, Iterator, Tuple
 from types import resolve_bases
 
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import _CHECKPOINT_PREFIX
 import torch
 
-if torch.distributed.is_available():
+# Type checking imports - always available for static analysis
+if TYPE_CHECKING or torch.distributed.is_available():
     from torch.distributed.fsdp.wrap import _Policy, CustomPolicy
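
Reviewer note (not part of the diff): `TYPE_CHECKING` is always False at runtime, so this guard preserves the existing runtime behavior (the FSDP names are only imported when torch.distributed is available) while letting pyright always resolve them instead of reporting them as possibly unbound; the same guard is applied in fsdp.py below. A minimal sketch of the pattern:

    from typing import TYPE_CHECKING

    import torch

    if TYPE_CHECKING or torch.distributed.is_available():
        # the checker always follows this branch; at runtime the import still
        # depends on torch.distributed.is_available()
        from torch.distributed.fsdp.wrap import CustomPolicy
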
diff --git a/src/finetuning_scheduler/strategy_adapters/base.py b/src/finetuning_scheduler/strategy_adapters/base.py
index 1da9908..8496183 100644
--- a/src/finetuning_scheduler/strategy_adapters/base.py
+++ b/src/finetuning_scheduler/strategy_adapters/base.py
@@ -18,7 +18,7 @@
 """
 from functools import partialmethod
 from pprint import pformat as pfmt
-from typing import Callable, List, Optional, Tuple, Dict
+from typing import Callable, Iterable, List, Optional, Tuple, Dict, Union
 
 import torch
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -329,22 +329,22 @@ def base_ft_phase(
     # dispatching pattern for module-specific handling)
     ####################################################################################################################
 
-    def _module_specific_freezing(self, modules: torch.nn.Module) -> None:
-        """Orchestrates module-specific freezing behavior. Currently only.
-
+    def _module_specific_freezing(self, modules: Union[torch.nn.Module, Iterable[torch.nn.Module]]) -> None:
+        """Orchestrates module-specific freezing behavior. Currently only
         :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` layers require special handling. Running
         statistics tracking for frozen `BatchNorm` layers is conditionally re-enabled here based on the
         `frozen_bn_track_running_stats` flag.
 
         Args:
-            modules (torch.nn.Module): The modules for which the `BatchNorm` layer running statistics should be enabled.
+            modules: The modules for which the `BatchNorm` layer running statistics should be enabled.
+                Can be a single Module or an iterable of Modules.
 
         Returns:
             None
         """
         if self.fts_handle.frozen_bn_track_running_stats:
             rank_zero_debug("Since `frozen_bn_track_running_stats` is currently set to `True`, FinetuningScheduler"
                             " will set `track_running_stats` to `True` for all `BatchNorm` layers.")
-            modules = BaseFinetuning.flatten_modules(modules)  # type: ignore[assignment]
+            modules = BaseFinetuning.flatten_modules(modules)
             for mod in modules:
                 if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm):
                     mod.track_running_stats = True
diff --git a/src/finetuning_scheduler/strategy_adapters/fsdp.py b/src/finetuning_scheduler/strategy_adapters/fsdp.py
index 1f18b1b..7102969 100644
--- a/src/finetuning_scheduler/strategy_adapters/fsdp.py
+++ b/src/finetuning_scheduler/strategy_adapters/fsdp.py
@@ -23,11 +23,11 @@
 import re
 import warnings
 from collections import Counter
-from contextlib import contextmanager
+from contextlib import AbstractContextManager, contextmanager
 from copy import deepcopy
 from functools import partial, partialmethod, wraps
 from pprint import pformat
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast
 from typing_extensions import override
 
 import torch
@@ -44,8 +44,8 @@
 
 from finetuning_scheduler.strategy_adapters.base import StrategyAdapter
 
-
-if torch.distributed.is_available():
+# Type checking imports - always available for static analysis
+if TYPE_CHECKING or torch.distributed.is_available():
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         FLAT_PARAM,
         FullyShardedDataParallel,
@@ -250,8 +250,11 @@ def load_optimizer_state_dict(self, checkpoint_connector: _CheckpointConnector)
         # rank0_only should be false to enable loading of the optimizer state on all ranks
         # irrespective of `use_orig_params` mode, we start with a full, unflattened, unsharded, consolidated osd
         # we then ensure the local osd is properly keyed and transformed for loading into each rank's local optimizer
-        with _get_full_state_dict_context(
-            self.pls_handle.model, world_size=self.pls_handle.world_size, rank0_only=False
+        with cast(
+            AbstractContextManager[None],
+            _get_full_state_dict_context(
+                self.pls_handle.model, world_size=self.pls_handle.world_size, rank0_only=False
+            ),
         ):
 
             for optimizer, opt_state in zip(self.pls_handle.optimizers, optimizer_states):
@@ -281,7 +284,10 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         assert self.pls_handle.model is not None
 
         # irrespective of `use_orig_params` mode, we need the full, unflattened, unsharded, consolidated osd
-        with _get_full_state_dict_context(self.pl_module, world_size=self.pls_handle.world_size, rank0_only=True):
+        with cast(
+            AbstractContextManager[None],
+            _get_full_state_dict_context(self.pl_module, world_size=self.pls_handle.world_size, rank0_only=True),
+        ):
             state_dict = FullyShardedDataParallel.optim_state_dict(self.pl_module, optimizer)
 
         return state_dict
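
Reviewer note (not part of the diff): `cast()` is erased at runtime; it only narrows what the checker believes `with` receives, presumably needed here because the Lightning helper's return annotation is too loose to satisfy pyright's context-manager protocol check under "standard" mode. A minimal sketch with a hypothetical stand-in for the helper:

    from contextlib import AbstractContextManager, nullcontext
    from typing import Any, cast

    def _loosely_typed_ctx() -> Any:  # stand-in for a helper with a loose return type
        return nullcontext()

    # cast() costs nothing at runtime; it documents and enforces the expected protocol
    with cast(AbstractContextManager[None], _loosely_typed_ctx()):
        pass
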