Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ jobs:
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: 0 # Ensures the full history is available
fetch-tags: true # Ensures tags are fetched
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
Expand Down
76 changes: 4 additions & 72 deletions nam/train/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,29 @@
Version utility
"""

from typing import Optional as _Optional

from .._version import __version__


class IncomparableVersionError(ValueError):
"""
Error raised when two versions can't be compared.
"""

pass
from .._version import __version__ as _package_version


class Version:
def __init__(self, major: int, minor: int, patch: int, dev: _Optional[str] = None):
def __init__(self, major: int, minor: int, patch: int):
self.major = major
self.minor = minor
self.patch = patch
self.dev = dev
self.dev_int = self._parse_dev_int(dev)

@classmethod
def from_string(cls, s: str):
def special_case(s: str) -> _Optional[dict]:
"""
Regretful hacks
"""
# It seems like the git repo isn't accessible to setuptools_scm's version
# guesser, so it comes up with this during install:
if s == "0.1.dev1":
# This will be fine.
return {
"major": 0,
"minor": 1,
"patch": 0,
"dev": "dev1"
}
return None

if special_case(s) is not None:
return cls(**special_case(s))

# Typical
parts = s.split(".")
if len(parts) == 3: # e.g. "0.7.1"
dev = None
elif len(parts) == 4: # e.g. "0.7.1.dev7"
dev = parts[3]
else:
raise ValueError(f"Invalid version string {s}")
try:
major, minor, patch = [int(x) for x in parts[:3]]
except ValueError as e:
raise ValueError(f"Failed to parse version from string '{s}':\n{e}")
return cls(major=major, minor=minor, patch=patch, dev=dev)
return cls(major=major, minor=minor, patch=patch)

def __eq__(self, other) -> bool:
return (
self.major == other.major
and self.minor == other.minor
and self.patch == other.patch
and self.dev == other.dev
)

def __lt__(self, other) -> bool:
Expand All @@ -79,45 +40,16 @@ def __lt__(self, other) -> bool:
return self.minor < other.minor
if self.patch != other.patch:
return self.patch < other.patch
if self.dev != other.dev:
# None is defined as least
if self.dev is None and other.dev is not None:
return True
elif self.dev is not None and other.dev is None:
return False
assert self.dev is not None
assert other.dev is not None
if self.dev_int is None:
raise IncomparableVersionError(
f"Version {str(self)} has incomparable dev version {self.dev}"
)
if other.dev_int is None:
raise IncomparableVersionError(
f"Version {str(other)} has incomparable dev version {other.dev}"
)
return self.dev_int < other.dev_int
raise RuntimeError(
f"Unhandled comparison between versions {str(self)} and {str(other)}"
)

def __str__(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"

def _parse_dev_int(self, dev: _Optional[str]) -> _Optional[int]:
"""
Turn the string into an int that can be compared if possible.
"""
if dev is None:
return None
if not isinstance(dev, str):
raise TypeError(f"Invalid dev string type {type(dev)}")
if not dev.startswith("dev") or len(dev) <= 3: # "misc", "dev", etc
return None
return int(dev.removeprefix("dev"))


PROTEUS_VERSION = Version(4, 0, 0)


def get_current_version() -> Version:
return Version.from_string(__version__)
return Version.from_string(_package_version)
20 changes: 0 additions & 20 deletions tests/test_nam/test_train/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,11 @@
from nam.train import _version


def test_dev_int():
"""
Assert that dev_int is properly parsed
"""
assert _version.Version(0, 0, 0).dev_int is None
assert _version.Version(0, 0, 0, "dev").dev_int is None
assert _version.Version(0, 0, 0, "misc").dev_int is None
assert _version.Version(0, 0, 0, "dev11").dev_int == 11


def test_eq():
assert _version.Version(0, 0, 0) == _version.Version(0, 0, 0)
assert _version.Version(0, 0, 0) != _version.Version(0, 0, 1)
assert _version.Version(0, 0, 0) != _version.Version(0, 1, 0)
assert _version.Version(0, 0, 0) != _version.Version(1, 0, 0)
assert _version.Version(0, 0, 0) != _version.Version(0, 0, 0, dev="dev0")
assert _version.Version(0, 0, 0) != _version.Version(0, 0, 0, dev="dev1")


def test_lt():
Expand All @@ -40,14 +28,6 @@ def test_lt():
assert not _version.Version(1, 2, 3) < _version.Version(0, 4, 5)


def test_lt_incomparable():
"""
Assert that the error is properly raised for incomparable versions
"""
with _pytest.raises(_version.IncomparableVersionError):
_version.Version(0, 0, 0, "incomparable") < _version.Version(0, 0, 0, "dev1")


def test_current_version():
"""
Test that the current version is valid
Expand Down