Skip to content

Commit d4d654f

Browse files
committed
New filtering syntax using a more natural python style
1 parent b9d05de commit d4d654f

2 files changed

Lines changed: 115 additions & 10 deletions

File tree

tests/test_filtering.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from tscat.filtering import Predicate, Comparison, Field, Attribute, Has, Match, Not, All, Any, In, UUID, \
1010
InCatalogue, PredicateRecursionError, CatalogueFilterError
1111

12+
from tscat import filtering
13+
1214
import datetime as dt
1315

1416
dates = [
@@ -38,7 +40,10 @@ class TestFilterRepr(unittest.TestCase):
3840
"All(Comparison('<=', Field('fieldName'), 'value'), Match(Field('fieldName'), '^mat[ch]{2}\\\\n$'))"),
3941
(In("Value", Field("FieldName")), "In('Value', Field('FieldName'))"),
4042
(InCatalogue(create_catalogue('Name', 'Author', uuid='957d65ae-f278-48f5-aab1-8cf50efeadef')),
41-
"InCatalogue(Catalogue(name=Name, author=Author, uuid=957d65ae-f278-48f5-aab1-8cf50efeadef, tags=[], predicate=None) attributes())")
43+
"InCatalogue(Catalogue(name=Name, author=Author, uuid=957d65ae-f278-48f5-aab1-8cf50efeadef, tags=[], predicate=None) attributes())"),
44+
(filtering.catalogue.some_field == 'value', "Comparison('==', Attribute('some_field'), 'value')"),
45+
(filtering.events.some_field == 'value', "Comparison('==', Attribute('some_field'), 'value')"),
46+
(filtering.events.another_field.matches(r'^pattern$'), "Match(Attribute('another_field'), '^pattern$')"),
4247
)
4348
@unpack
4449
def test_predicate_repr(self, pred: Predicate, expected: str) -> None:
@@ -153,6 +158,25 @@ def test_logical_combinations(self, pred, idx):
153158
event_list = get_events(All(pred))
154159
self.assertListEqual(event_list, [events[i] for i in idx])
155160

161+
def test_new_syntax(self):
162+
# Test the new syntax for filtering
163+
event_list = get_events(filtering.events.author == 'Patrick')
164+
self.assertListEqual(event_list, [events[0]])
165+
166+
event_list = get_events((filtering.events.author == 'Patrick') | (filtering.events.author == 'Alexis'))
167+
self.assertListEqual(event_list, [events[0], events[1]])
168+
169+
event_list = get_events((filtering.events.author == 'Patrick') & (filtering.events.a == 1))
170+
self.assertListEqual(event_list, [events[0]])
171+
172+
event_list = get_events(filtering.events.s.matches(r'^Go.*'))
173+
self.assertListEqual(event_list, [events[2]])
174+
175+
event_list = get_events(~(filtering.events.s.matches(r'^Go.*') & (filtering.events.h == 30)))
176+
self.assertListEqual(event_list, [events[0], events[1]])
177+
178+
179+
156180
def test_get_only_manually_added_events_from_dynamic_catalogue(self):
157181
cat = create_catalogue('T', 'A')
158182
cat.predicate = Comparison("==", Field('author'), 'Patrick')

tscat/filtering.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,83 @@
77
if TYPE_CHECKING:
88
from . import _Catalogue
99

10+
MemberValueType = Union[str, int, float, dt.datetime, bool]
1011

11-
class Field:
12+
class _Member:
1213
def __init__(self, name: str):
1314
self.value = name
1415

16+
def __eq__(self, other: MemberValueType) -> 'Predicate':
17+
return Comparison('==', self, other)
18+
19+
def __ne__(self, other) -> 'Predicate':
20+
return Comparison('!=', self, other)
21+
22+
def __gt__(self, other) -> 'Predicate':
23+
return Comparison('>', self, other)
24+
25+
def __lt__(self, other) -> 'Predicate':
26+
return Comparison('<', self, other)
27+
28+
def __ge__(self, other) -> 'Predicate':
29+
return Comparison('>=', self, other)
30+
31+
def __le__(self, other) -> 'Predicate':
32+
return Comparison('<=', self, other)
33+
34+
def matches(self, value: str) -> 'Predicate':
35+
return Match(self, value)
36+
37+
38+
class Field(_Member):
39+
def __init__(self, name: str):
40+
super().__init__(name)
41+
1542
def __repr__(self):
1643
return f"Field('{self.value}')"
1744

1845

19-
class Attribute:
46+
class Attribute(_Member):
2047
def __init__(self, name: str):
21-
self.value = name
48+
super().__init__(name)
2249

2350
def __repr__(self):
2451
return f"Attribute('{self.value}')"
2552

53+
def exists(self) -> 'Predicate':
54+
return Has(self)
2655

2756
class Predicate:
2857
def __eq__(self, o):
2958
return repr(self) == repr(o)
3059

60+
def __and__(self, other):
61+
if isinstance(other, Predicate):
62+
return All(self, other)
63+
elif isinstance(other, (list, tuple)) and all(isinstance(item, Predicate) for item in other):
64+
return All(self, *other)
65+
else:
66+
raise TypeError(f"Cannot combine {type(self).__name__} with {type(other).__name__}")
67+
68+
def __or__(self, other):
69+
if isinstance(other, Predicate):
70+
return Any(self, other)
71+
elif isinstance(other, (list, tuple)) and all(isinstance(item, Predicate) for item in other):
72+
return Any(self, *other)
73+
else:
74+
raise TypeError(f"Cannot combine {type(self).__name__} with {type(other).__name__}")
75+
76+
def __invert__(self):
77+
return Not(self)
78+
3179

3280
class Comparison(Predicate):
3381
def __init__(self,
3482
op: Union[Literal['>'], Literal['>='],
35-
Literal['<'], Literal['<='],
36-
Literal['=='], Literal['!=']],
37-
lhs: Union[Field, Attribute],
38-
rhs: Union[str, int, float, dt.datetime, bool]):
83+
Literal['<'], Literal['<='],
84+
Literal['=='], Literal['!=']],
85+
lhs: _Member,
86+
rhs: MemberValueType):
3987
self._op = op
4088
self._lhs = lhs
4189
self._rhs = rhs
@@ -46,7 +94,7 @@ def __repr__(self):
4694

4795
class Match(Predicate):
4896
def __init__(self,
49-
lhs: Union[Field, Attribute],
97+
lhs: _Member,
5098
rhs: str): # regex
5199
self._lhs = lhs
52100
self._rhs = rhs
@@ -88,7 +136,7 @@ def __repr__(self):
88136

89137

90138
class In(Predicate):
91-
def __init__(self, lhs: str, rhs: Union[Field, Attribute]):
139+
def __init__(self, lhs: str, rhs: _Member):
92140
self._lhs = lhs
93141
self._rhs = rhs
94142

@@ -120,3 +168,36 @@ def __init__(self, message: str, predicate: Predicate):
120168
class CatalogueFilterError(Exception):
121169
def __init__(self, message: str):
122170
super().__init__(message)
171+
172+
173+
174+
class _Catalogue:
175+
def __init__(self):
176+
pass
177+
178+
def __contains__(self, item: "_Event") -> Predicate:
179+
return InCatalogue(item)
180+
181+
def __getattr__(self, item) -> _Member:
182+
if item in ('name', 'author', 'uuid', 'tags', 'predicate', 'attributes'):
183+
return Field(item)
184+
return Attribute(item)
185+
186+
187+
class _Events:
188+
def __init__(self):
189+
pass
190+
191+
def __contains__(self, item: str) -> Predicate:
192+
return self[item].exists()
193+
194+
def __getattr__(self, item) -> _Member:
195+
if item in ('start', 'stop', 'author', 'tags', 'products', 'rating', 'uuid'):
196+
return Field(item)
197+
return Attribute(item)
198+
199+
200+
# tokens to create predicates from Python code
201+
catalogue = _Catalogue()
202+
events = _Events()
203+
event = events

0 commit comments

Comments
 (0)