Skip to content

Commit af1a6e4

Browse files
committed
Add comparison operators to when() conditions
Extend Condition.when() to support Django-style lookup operators via double-underscore suffixes (e.g. when(age__gte=18), when(status__in=[...])). Supported operators: eq, ne, gt, gte, lt, lte, in, notin, contains. Plain key=value usage remains fully backward compatible. Includes 43 new parametrized tests and updated transition docs.
1 parent 5d7267c commit af1a6e4

3 files changed

Lines changed: 260 additions & 9 deletions

File tree

burr/core/action.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -423,26 +423,98 @@ def run(self, state: State, **run_kwargs) -> dict:
423423
def reads(self) -> list[str]:
424424
return self._keys
425425

426+
_OPERATORS = {
427+
"eq": ("==", lambda a, b: a == b),
428+
"ne": ("!=", lambda a, b: a != b),
429+
"lt": ("<", lambda a, b: a < b),
430+
"lte": ("<=", lambda a, b: a <= b),
431+
"gt": (">", lambda a, b: a > b),
432+
"gte": (">=", lambda a, b: a >= b),
433+
"in": ("in", lambda a, b: a in b),
434+
"notin": ("not in", lambda a, b: a not in b),
435+
"contains": ("contains", lambda a, b: b in a),
436+
}
437+
438+
@classmethod
439+
def _parse_kwarg(cls, kwarg_key: str, value):
440+
"""Parse a kwarg key into (state_key, operator_symbol, comparison_func, explicit).
441+
442+
Supports Django-style lookups: ``key__gte=10`` parses as key >= 10.
443+
Plain ``key=value`` defaults to equality (implicit).
444+
445+
Returns a tuple of (state_key, symbol, func, explicit) where explicit
446+
indicates whether an operator suffix was present.
447+
"""
448+
for suffix, (symbol, func) in cls._OPERATORS.items():
449+
dunder = f"__{suffix}"
450+
if kwarg_key.endswith(dunder):
451+
state_key = kwarg_key[: -len(dunder)]
452+
if not state_key:
453+
raise ValueError(
454+
f"Invalid when() key: '{kwarg_key}' — " f"no state key before '__{suffix}'"
455+
)
456+
return state_key, symbol, func, True
457+
return kwarg_key, "=", lambda a, b: a == b, False
458+
426459
@classmethod
427460
def when(cls, **kwargs):
428-
"""Returns a condition that checks if the given keys are in the
429-
state and equal to the given values.
461+
"""Returns a condition that checks state values using optional operators.
430462
431463
You can also refer to this as ``from burr.core import when`` in the API.
432464
433-
:param kwargs: Keyword arguments of keys and values to check -- will be an AND condition
434-
:return: A condition that checks if the given keys are in the state and equal to the given values
465+
Basic equality (unchanged from original)::
466+
467+
when(foo="bar") # state["foo"] == "bar"
468+
when(foo="bar", baz="qux") # state["foo"] == "bar" AND state["baz"] == "qux"
469+
470+
Comparison operators via ``__`` suffix::
471+
472+
when(age__gt=18) # state["age"] > 18
473+
when(age__gte=18) # state["age"] >= 18
474+
when(age__lt=18) # state["age"] < 18
475+
when(age__lte=18) # state["age"] <= 18
476+
when(age__ne=0) # state["age"] != 0
477+
when(age__eq=18) # state["age"] == 18 (explicit)
478+
479+
Membership operators::
480+
481+
when(status__in=["a", "b"]) # state["status"] in ["a", "b"]
482+
when(status__notin=["x", "y"]) # state["status"] not in ["x", "y"]
483+
when(tags__contains="python") # "python" in state["tags"]
484+
485+
Multiple conditions are ANDed together::
486+
487+
when(age__gte=18, status="active") # age >= 18 AND status == "active"
488+
489+
:param kwargs: Keyword arguments with optional ``__operator`` suffixes
490+
:return: A condition that checks all specified constraints (AND)
435491
"""
436-
keys = list(kwargs.keys())
492+
parsed = []
493+
for kwarg_key, value in kwargs.items():
494+
state_key, symbol, func, explicit = cls._parse_kwarg(kwarg_key, value)
495+
parsed.append((state_key, symbol, func, value, explicit))
496+
497+
state_keys = list(dict.fromkeys(p[0] for p in parsed))
437498

438499
def condition_func(state: State) -> bool:
439-
for key, value in kwargs.items():
440-
if state.get(key) != value:
500+
for state_key, _symbol, func, value, _explicit in parsed:
501+
if not func(state.get(state_key), value):
441502
return False
442503
return True
443504

444-
name = f"{', '.join(f'{key}={value}' for key, value in sorted(kwargs.items()))}"
445-
return Condition(keys, condition_func, name=name)
505+
name_parts = []
506+
for state_key, symbol, _func, value, explicit in sorted(parsed, key=lambda p: p[0]):
507+
if not explicit:
508+
# Backward-compatible format: key=value (no repr, no spaces)
509+
name_parts.append(f"{state_key}={value}")
510+
elif symbol.isalnum() or " " in symbol:
511+
# Word operators like "in", "not in", "contains"
512+
name_parts.append(f"{state_key} {symbol} {value!r}")
513+
else:
514+
# Symbol operators like >=, !=, etc.
515+
name_parts.append(f"{state_key}{symbol}{value!r}")
516+
name = ", ".join(name_parts)
517+
return Condition(state_keys, condition_func, name=name)
446518

447519
def __repr__(self):
448520
return f"condition: {self._name}"

docs/concepts/transitions.rst

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,36 @@ Conditions have a few APIs, but the most common are the three convenience functi
5454
)
5555
5656
57+
``when()`` also supports comparison operators via Django-style ``__`` suffixes:
58+
59+
.. code-block:: python
60+
61+
from burr.core import when
62+
with_transitions(
63+
("check", "adult", when(age__gte=18)), # age >= 18
64+
("check", "child", when(age__lt=18)), # age < 18
65+
("check", "valid", when(score__gt=0, score__lte=100)), # 0 < score <= 100
66+
("check", "active", when(status__in=["active", "pending"])), # membership
67+
("check", "tagged", when(tags__contains="python")), # collection contains value
68+
("check", "clean", when(status__notin=["banned", "suspended"])), # not in
69+
("check", "changed", when(status__ne="initial")), # not equal
70+
)
71+
72+
Available operators:
73+
74+
- ``key=value`` — exact equality (default, unchanged)
75+
- ``key__eq=value`` — explicit equality
76+
- ``key__ne=value``not equal
77+
- ``key__gt=value`` — greater than
78+
- ``key__gte=value`` — greater than or equal
79+
- ``key__lt=value`` — less than
80+
- ``key__lte=value`` — less than or equal
81+
- ``key__in=[values]`` — value is in the given collection
82+
- ``key__notin=[values]`` — value is not in the given collection
83+
- ``key__contains=value`` — collection/string in state contains the value
84+
85+
Multiple keyword arguments are ANDed together. For more complex expressions, use ``expr()``.
86+
5787
Conditions are evaluated in the order they are specified, and the first one that evaluates to True will be the transition that is selected
5888
when determining which action to run next. If no condition evaluates to ``True``, the application execution will stop early.
5989

tests/core/test_action.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,155 @@ def test_condition_when_complex():
126126
assert cond.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False}
127127

128128

129+
# --- when() operator tests ---
130+
131+
132+
@pytest.mark.parametrize(
133+
"kwargs,state_dict,expected",
134+
[
135+
# __eq (explicit equality)
136+
({"age__eq": 18}, {"age": 18}, True),
137+
({"age__eq": 18}, {"age": 19}, False),
138+
# __ne (not equal)
139+
({"age__ne": 0}, {"age": 5}, True),
140+
({"age__ne": 0}, {"age": 0}, False),
141+
# __gt (greater than)
142+
({"age__gt": 18}, {"age": 19}, True),
143+
({"age__gt": 18}, {"age": 18}, False),
144+
({"age__gt": 18}, {"age": 17}, False),
145+
# __gte (greater than or equal)
146+
({"age__gte": 18}, {"age": 19}, True),
147+
({"age__gte": 18}, {"age": 18}, True),
148+
({"age__gte": 18}, {"age": 17}, False),
149+
# __lt (less than)
150+
({"age__lt": 18}, {"age": 17}, True),
151+
({"age__lt": 18}, {"age": 18}, False),
152+
({"age__lt": 18}, {"age": 19}, False),
153+
# __lte (less than or equal)
154+
({"age__lte": 18}, {"age": 17}, True),
155+
({"age__lte": 18}, {"age": 18}, True),
156+
({"age__lte": 18}, {"age": 19}, False),
157+
# __in (membership)
158+
({"status__in": ["active", "pending"]}, {"status": "active"}, True),
159+
({"status__in": ["active", "pending"]}, {"status": "pending"}, True),
160+
({"status__in": ["active", "pending"]}, {"status": "banned"}, False),
161+
# __notin (not in)
162+
({"status__notin": ["banned", "suspended"]}, {"status": "active"}, True),
163+
({"status__notin": ["banned", "suspended"]}, {"status": "banned"}, False),
164+
# __contains (collection contains value)
165+
({"tags__contains": "python"}, {"tags": ["python", "java"]}, True),
166+
({"tags__contains": "go"}, {"tags": ["python", "java"]}, False),
167+
({"text__contains": "hello"}, {"text": "say hello world"}, True),
168+
({"text__contains": "goodbye"}, {"text": "say hello world"}, False),
169+
],
170+
ids=[
171+
"eq-match",
172+
"eq-no-match",
173+
"ne-different",
174+
"ne-same",
175+
"gt-above",
176+
"gt-equal",
177+
"gt-below",
178+
"gte-above",
179+
"gte-equal",
180+
"gte-below",
181+
"lt-below",
182+
"lt-equal",
183+
"lt-above",
184+
"lte-below",
185+
"lte-equal",
186+
"lte-above",
187+
"in-first",
188+
"in-second",
189+
"in-missing",
190+
"notin-absent",
191+
"notin-present",
192+
"contains-list-match",
193+
"contains-list-no-match",
194+
"contains-str-match",
195+
"contains-str-no-match",
196+
],
197+
)
198+
def test_condition_when_operators(kwargs, state_dict, expected):
199+
cond = Condition.when(**kwargs)
200+
assert cond.run(State(state_dict)) == {Condition.KEY: expected}
201+
202+
203+
@pytest.mark.parametrize(
204+
"kwargs,expected_reads",
205+
[
206+
({"age__gte": 18}, ["age"]),
207+
({"status__in": ["a"]}, ["status"]),
208+
({"tags__contains": "x"}, ["tags"]),
209+
({"age__gte": 18, "status": "active"}, ["age", "status"]),
210+
# same key with different operators
211+
({"age__gte": 10, "age__lt": 20}, ["age"]),
212+
],
213+
ids=["gte", "in", "contains", "mixed", "same-key-two-ops"],
214+
)
215+
def test_condition_when_operators_reads(kwargs, expected_reads):
216+
cond = Condition.when(**kwargs)
217+
assert sorted(cond.reads) == sorted(expected_reads)
218+
219+
220+
@pytest.mark.parametrize(
221+
"kwargs,expected_name",
222+
[
223+
({"age__gte": 18}, "age>=18"),
224+
({"age__lt": 5}, "age<5"),
225+
({"age__ne": 0}, "age!=0"),
226+
({"status__in": ["a", "b"]}, "status in ['a', 'b']"),
227+
({"status__notin": ["x"]}, "status not in ['x']"),
228+
({"tags__contains": "py"}, "tags contains 'py'"),
229+
# plain equality still uses old format
230+
({"foo": "bar"}, "foo=bar"),
231+
({"foo": "bar", "baz": "qux"}, "baz=qux, foo=bar"),
232+
],
233+
ids=["gte", "lt", "ne", "in", "notin", "contains", "plain-eq", "plain-multi"],
234+
)
235+
def test_condition_when_operators_name(kwargs, expected_name):
236+
cond = Condition.when(**kwargs)
237+
assert cond.name == expected_name
238+
239+
240+
def test_condition_when_operators_combined():
241+
"""Test multiple operators ANDed together."""
242+
cond = Condition.when(age__gte=18, status="active", score__lt=100)
243+
assert cond.run(State({"age": 20, "status": "active", "score": 50})) == {Condition.KEY: True}
244+
assert cond.run(State({"age": 17, "status": "active", "score": 50})) == {Condition.KEY: False}
245+
assert cond.run(State({"age": 20, "status": "inactive", "score": 50})) == {Condition.KEY: False}
246+
assert cond.run(State({"age": 20, "status": "active", "score": 100})) == {Condition.KEY: False}
247+
248+
249+
def test_condition_when_operators_with_invert():
250+
"""Ensure operator-based conditions work with ~ (invert)."""
251+
cond = ~Condition.when(age__gte=18)
252+
assert cond.run(State({"age": 17})) == {Condition.KEY: True}
253+
assert cond.run(State({"age": 18})) == {Condition.KEY: False}
254+
255+
256+
def test_condition_when_operators_with_or():
257+
"""Ensure operator-based conditions work with | (or)."""
258+
cond = Condition.when(age__lt=13) | Condition.when(age__gte=65)
259+
assert cond.run(State({"age": 10})) == {Condition.KEY: True}
260+
assert cond.run(State({"age": 70})) == {Condition.KEY: True}
261+
assert cond.run(State({"age": 30})) == {Condition.KEY: False}
262+
263+
264+
def test_condition_when_operators_with_and():
265+
"""Ensure operator-based conditions work with & (and)."""
266+
cond = Condition.when(age__gte=18) & Condition.when(age__lt=65)
267+
assert cond.run(State({"age": 30})) == {Condition.KEY: True}
268+
assert cond.run(State({"age": 17})) == {Condition.KEY: False}
269+
assert cond.run(State({"age": 65})) == {Condition.KEY: False}
270+
271+
272+
def test_condition_when_invalid_key():
273+
"""Empty state key before operator suffix should raise."""
274+
with pytest.raises(ValueError, match="no state key"):
275+
Condition.when(__gte=18)
276+
277+
129278
def test_condition_default():
130279
cond = default
131280
assert cond.name == "default"

0 commit comments

Comments
 (0)