Skip to content

Commit a7aab9b

Browse files
committed
clarify toggle cli usage, enable more testing for draft mode PRs, add release-based install mode fallback if user specifies commit pin based install in a context that it is not possible
1 parent 7ec2a72 commit a7aab9b

File tree

5 files changed

+38
-4
lines changed

5 files changed

+38
-4
lines changed

.github/workflows/ci_test-full.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ jobs:
3636

3737
cpu:
3838
runs-on: ${{ matrix.os }}
39-
if: github.event.pull_request.draft == false
4039
strategy:
4140
fail-fast: false
4241
matrix:

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ toggle-lightning-mode --mode standalone
142142
toggle-lightning-mode --mode unified
143143
```
144144

145+
> **Note:** If you have the standalone package (`pytorch-lightning`) installed but not the unified package (`lightning`), toggling to unified mode will be prevented. You must install the `lightning` package first before toggling.
146+
145147
This can be useful when:
146148

147149
- You need to adapt existing code to work with a different Lightning package

setup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222

2323
_PACKAGE_NAME = os.environ.get("PACKAGE_NAME")
2424
_PACKAGE_MODES = ("pytorch", "lightning")
25-
2625
_PATH_ROOT = Path(os.path.abspath(os.path.dirname(__file__)))
2726
_INSTALL_PATHS = {}
2827
for p, d in zip(["source", "tests", "require"], ["src", "tests", "requirements"]):
2928
_INSTALL_PATHS[p] = _PATH_ROOT / d
30-
3129
_CORE_FTS_LOC = _INSTALL_PATHS["source"] / "finetuning_scheduler"
3230
_DYNAMIC_VERSIONING_LOC = _CORE_FTS_LOC / "dynamic_versioning"
3331

@@ -101,7 +99,7 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]:
10199
# Print additional info about Lightning installation method
102100
use_commit = os.environ.get("USE_CI_COMMIT_PIN", "").lower() in ("1", "true", "yes")
103101
if use_commit:
104-
print("Using Lightning from specific commit (dev mode)")
102+
print("Using Lightning from specific commit (dev/ci mode)")
105103
else:
106104
print("Using Lightning from PyPI")
107105

src/finetuning_scheduler/dynamic_versioning/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ def get_lightning_requirement(package_type: str = "unified", use_commit: bool =
100100

101101
project_root = Path(__file__).parent.parent.parent.parent
102102
LIGHTNING_COMMIT_FILE = os.path.join(project_root, "requirements/lightning_pin.txt")
103+
104+
# Check if the commit file exists
105+
if not os.path.exists(LIGHTNING_COMMIT_FILE):
106+
print(f"Warning: USE_CI_COMMIT_PIN is set but {LIGHTNING_COMMIT_FILE} does not exist.")
107+
print(f"Falling back to release-based installation: {package_name}{pkg_info['version']}")
108+
return f"{package_name}{pkg_info['version']}"
109+
103110
with open(LIGHTNING_COMMIT_FILE) as f:
104111
LIGHTNING_COMMIT = f.read().strip()
105112

tests/test_dynamic_versioning_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,34 @@ def test_get_lightning_requirement():
3535
assert "abc123" in req
3636

3737

38+
def test_get_lightning_requirement_missing_pin_file():
39+
"""Test getting the Lightning requirement string when commit pin file is missing."""
40+
# Mock os.path.exists to return False for the lightning_pin.txt file
41+
with patch('os.path.exists', return_value=False), \
42+
patch('builtins.print') as mock_print:
43+
44+
# Test with use_commit=True but missing pin file
45+
req = get_lightning_requirement("unified", True)
46+
47+
# Should fall back to release-based installation
48+
assert "lightning>=" in req
49+
assert "@" not in req
50+
51+
# Should print warning messages
52+
assert any("Warning: USE_CI_COMMIT_PIN is set but" in call_args[0][0]
53+
for call_args in mock_print.call_args_list)
54+
assert any("Falling back to release-based installation" in call_args[0][0]
55+
for call_args in mock_print.call_args_list)
56+
57+
# Reset mock
58+
mock_print.reset_mock()
59+
60+
# Same for standalone package
61+
req = get_lightning_requirement("standalone", True)
62+
assert "pytorch-lightning>=" in req
63+
assert "@" not in req
64+
65+
3866
def test_lightning_package_mapping():
3967
"""Test Lightning package mapping constants."""
4068
assert "lightning.pytorch" in LIGHTNING_PACKAGE_MAPPING

0 commit comments

Comments
 (0)