Skip to content

Commit 9d533dc

Browse files
Make (N)EVR(A) objects comparable (#379)
Make (N)EVR(A) objects comparable Related to packit/packit-service#2378. In the end it will probably not be needed for the sidetags related stuff, but it can be useful nevertheless. Reviewed-by: Laura Barcziová
2 parents 9f26a94 + b842ff4 commit 9d533dc

File tree

3 files changed

+156
-8
lines changed

3 files changed

+156
-8
lines changed

specfile/changelog.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import subprocess
1212
from typing import List, Optional, Union, overload
1313

14-
import rpm
15-
1614
from specfile.exceptions import SpecfileException
1715
from specfile.formatter import formatted
1816
from specfile.macros import Macros
@@ -282,10 +280,9 @@ def filter(
282280

283281
def parse_evr(s):
284282
try:
285-
evr = EVR.from_string(s)
283+
return EVR.from_string(s)
286284
except SpecfileException:
287-
return "0", "0", ""
288-
return str(evr.epoch), evr.version or "0", evr.release
285+
return EVR(version="0")
289286

290287
if since is None:
291288
start_index = 0
@@ -294,7 +291,7 @@ def parse_evr(s):
294291
(
295292
i
296293
for i, e in enumerate(self.data)
297-
if rpm.labelCompare(parse_evr(e.evr), parse_evr(since)) >= 0
294+
if parse_evr(e.evr) >= parse_evr(since)
298295
),
299296
len(self.data) + 1,
300297
)
@@ -305,7 +302,7 @@ def parse_evr(s):
305302
(
306303
i + 1
307304
for i, e in reversed(list(enumerate(self.data)))
308-
if rpm.labelCompare(parse_evr(e.evr), parse_evr(until)) <= 0
305+
if parse_evr(e.evr) <= parse_evr(until)
309306
),
310307
0,
311308
)

specfile/utils.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import sys
77
from typing import TYPE_CHECKING, Tuple
88

9+
import rpm
10+
911
from specfile.constants import ARCH_NAMES
1012
from specfile.exceptions import SpecfileException, UnterminatedMacroException
1113
from specfile.formatter import formatted
@@ -28,10 +30,41 @@ def _key(self) -> tuple:
2830
def __hash__(self) -> int:
2931
return hash(self._key())
3032

33+
def _rpm_evr_tuple(self) -> Tuple[str, str, str]:
34+
return str(self.epoch), self.version or "0", self.release
35+
36+
def _cmp(self, other: "EVR") -> int:
37+
return rpm.labelCompare(self._rpm_evr_tuple(), other._rpm_evr_tuple())
38+
39+
def __lt__(self, other: object) -> bool:
40+
if type(other) is not self.__class__:
41+
return NotImplemented
42+
return self._cmp(other) < 0
43+
44+
def __le__(self, other: object) -> bool:
45+
if type(other) is not self.__class__:
46+
return NotImplemented
47+
return self._cmp(other) <= 0
48+
3149
def __eq__(self, other: object) -> bool:
3250
if type(other) is not self.__class__:
3351
return NotImplemented
34-
return self._key() == other._key()
52+
return self._cmp(other) == 0
53+
54+
def __ne__(self, other: object) -> bool:
55+
if type(other) is not self.__class__:
56+
return NotImplemented
57+
return self._cmp(other) != 0
58+
59+
def __ge__(self, other: object) -> bool:
60+
if type(other) is not self.__class__:
61+
return NotImplemented
62+
return self._cmp(other) >= 0
63+
64+
def __gt__(self, other: object) -> bool:
65+
if type(other) is not self.__class__:
66+
return NotImplemented
67+
return self._cmp(other) > 0
3568

3669
@formatted
3770
def __repr__(self) -> str:
@@ -65,6 +98,44 @@ def __init__(
6598
def _key(self) -> tuple:
6699
return self.name, self.epoch, self.version, self.release
67100

101+
def __lt__(self, other: object) -> bool:
102+
if type(other) is not self.__class__:
103+
return NotImplemented
104+
if self.name != other.name:
105+
return NotImplemented
106+
return self._cmp(other) < 0
107+
108+
def __le__(self, other: object) -> bool:
109+
if type(other) is not self.__class__:
110+
return NotImplemented
111+
if self.name != other.name:
112+
return NotImplemented
113+
return self._cmp(other) <= 0
114+
115+
def __eq__(self, other: object) -> bool:
116+
if type(other) is not self.__class__:
117+
return NotImplemented
118+
return self.name == other.name and self._cmp(other) == 0
119+
120+
def __ne__(self, other: object) -> bool:
121+
if type(other) is not self.__class__:
122+
return NotImplemented
123+
return self.name != other.name or self._cmp(other) != 0
124+
125+
def __ge__(self, other: object) -> bool:
126+
if type(other) is not self.__class__:
127+
return NotImplemented
128+
if self.name != other.name:
129+
return NotImplemented
130+
return self._cmp(other) >= 0
131+
132+
def __gt__(self, other: object) -> bool:
133+
if type(other) is not self.__class__:
134+
return NotImplemented
135+
if self.name != other.name:
136+
return NotImplemented
137+
return self._cmp(other) > 0
138+
68139
@formatted
69140
def __repr__(self) -> str:
70141
return (
@@ -101,6 +172,50 @@ def __init__(
101172
def _key(self) -> tuple:
102173
return self.name, self.epoch, self.version, self.release, self.arch
103174

175+
def __lt__(self, other: object) -> bool:
176+
if type(other) is not self.__class__:
177+
return NotImplemented
178+
if self.name != other.name or self.arch != other.arch:
179+
return NotImplemented
180+
return self._cmp(other) < 0
181+
182+
def __le__(self, other: object) -> bool:
183+
if type(other) is not self.__class__:
184+
return NotImplemented
185+
if self.name != other.name or self.arch != other.arch:
186+
return NotImplemented
187+
return self._cmp(other) <= 0
188+
189+
def __eq__(self, other: object) -> bool:
190+
if type(other) is not self.__class__:
191+
return NotImplemented
192+
return (
193+
self.name == other.name
194+
and self.arch == other.arch
195+
and self._cmp(other) == 0
196+
)
197+
198+
def __ne__(self, other: object) -> bool:
199+
if type(other) is not self.__class__:
200+
return NotImplemented
201+
return (
202+
self.name != other.name or self.arch != other.arch or self._cmp(other) != 0
203+
)
204+
205+
def __ge__(self, other: object) -> bool:
206+
if type(other) is not self.__class__:
207+
return NotImplemented
208+
if self.name != other.name or self.arch != other.arch:
209+
return NotImplemented
210+
return self._cmp(other) >= 0
211+
212+
def __gt__(self, other: object) -> bool:
213+
if type(other) is not self.__class__:
214+
return NotImplemented
215+
if self.name != other.name or self.arch != other.arch:
216+
return NotImplemented
217+
return self._cmp(other) > 0
218+
104219
@formatted
105220
def __repr__(self) -> str:
106221
return (

tests/unit/test_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,42 @@ def test_get_filename_from_location(location, filename):
3232
assert get_filename_from_location(location) == filename
3333

3434

35+
def test_EVR_compare():
36+
assert EVR(version="0") == EVR(version="0")
37+
assert EVR(version="0", release="1") != EVR(version="0", release="2")
38+
assert EVR(version="12.0", release="1") <= EVR(version="12.0", release="1")
39+
assert EVR(version="12.0", release="1") <= EVR(version="12.0", release="2")
40+
assert EVR(epoch=2, version="56.8", release="5") > EVR(
41+
epoch=1, version="99.2", release="2"
42+
)
43+
44+
45+
def test_NEVR_compare():
46+
assert NEVR(name="test", version="1", release="1") == NEVR(
47+
name="test", version="1", release="1"
48+
)
49+
assert NEVR(name="test", version="3", release="1") != NEVR(
50+
name="test2", version="3", release="1"
51+
)
52+
with pytest.raises(TypeError):
53+
NEVR(name="test", version="3", release="1") > NEVR(
54+
name="test2", version="1", release="2"
55+
)
56+
57+
58+
def test_NEVRA_compare():
59+
assert NEVRA(name="test", version="1", release="1", arch="x86_64") == NEVRA(
60+
name="test", version="1", release="1", arch="x86_64"
61+
)
62+
assert NEVRA(name="test", version="2", release="1", arch="x86_64") != NEVRA(
63+
name="test", version="2", release="1", arch="aarch64"
64+
)
65+
with pytest.raises(TypeError):
66+
NEVRA(name="test", version="1", release="1", arch="aarch64") < NEVRA(
67+
name="test", version="2", release="1", arch="x86_64"
68+
)
69+
70+
3571
@pytest.mark.parametrize(
3672
"evr, result",
3773
[

0 commit comments

Comments
 (0)