Skip to content

Commit f006b55

Browse files
authored
SADDEX - implement set with support for expiring members (#350)
1 parent 0afe9a5 commit f006b55

File tree

12 files changed

+229
-116
lines changed

12 files changed

+229
-116
lines changed

docs/about/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ toc_depth: 2
1212
### 🚀 Features
1313

1414
- Add support disable_decoding in async read_response #349
15+
- Implement support for `SADDEX`, using a new set implementation with support for expiring members #350
1516

1617
## v2.26.2
1718

fakeredis/_basefakesocket.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
import redis
99
from redis.connection import DefaultParser
1010

11-
from fakeredis.model import XStream
12-
from fakeredis.model import ZSet
11+
from fakeredis.model import XStream, ZSet, Hash, ExpiringMembersSet
1312
from . import _msgs as msgs
1413
from ._command_args_parsing import extract_args
15-
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash
14+
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem
1615
from ._helpers import (
1716
SimpleError,
1817
valid_response_type,
@@ -392,7 +391,7 @@ def _key_value_type(key: CommandItem) -> SimpleString:
392391
return SimpleString(b"string")
393392
elif isinstance(key.value, list):
394393
return SimpleString(b"list")
395-
elif isinstance(key.value, set):
394+
elif isinstance(key.value, ExpiringMembersSet):
396395
return SimpleString(b"set")
397396
elif isinstance(key.value, ZSet):
398397
return SimpleString(b"zset")

fakeredis/_commands.py

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import re
99
import sys
1010
import time
11-
from typing import Iterable, Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set, Collection
11+
from typing import Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set, Collection
1212

1313
from . import _msgs as msgs
14-
from ._helpers import null_terminate, SimpleError, Database, current_time
14+
from ._helpers import null_terminate, SimpleError, Database
1515

1616
MAX_STRING_SIZE = 512 * 1024 * 1024
1717
SUPPORTED_COMMANDS: Dict[str, "Signature"] = dict() # Dictionary of supported commands name => Signature
@@ -107,82 +107,6 @@ def __bool__(self) -> bool:
107107
__nonzero__ = __bool__ # For Python 2
108108

109109

110-
class Hash:
111-
DECODE_ERROR = msgs.INVALID_HASH_MSG
112-
redis_type = b"hash"
113-
114-
def __init__(self, *args: Any, **kwargs: Any) -> None:
115-
super().__init__(*args, **kwargs)
116-
self._expirations: Dict[bytes, int] = dict()
117-
self._values: Dict[bytes, Any] = dict()
118-
119-
def _expire_keys(self) -> None:
120-
removed = []
121-
now = current_time()
122-
for k in self._expirations:
123-
if self._expirations[k] < now:
124-
self._values.pop(k, None)
125-
removed.append(k)
126-
for k in removed:
127-
self._expirations.pop(k, None)
128-
129-
def set_key_expireat(self, key: bytes, when_ms: int) -> int:
130-
now = current_time()
131-
if when_ms <= now:
132-
self._values.pop(key, None)
133-
self._expirations.pop(key, None)
134-
return 2
135-
self._expirations[key] = when_ms
136-
return 1
137-
138-
def clear_key_expireat(self, key: bytes) -> bool:
139-
return self._expirations.pop(key, None) is not None
140-
141-
def get_key_expireat(self, key: bytes) -> Optional[int]:
142-
self._expire_keys()
143-
return self._expirations.get(key, None)
144-
145-
def __getitem__(self, key: bytes) -> Any:
146-
self._expire_keys()
147-
return self._values.get(key)
148-
149-
def __contains__(self, key: bytes) -> bool:
150-
self._expire_keys()
151-
return self._values.__contains__(key)
152-
153-
def __setitem__(self, key: bytes, value: Any) -> None:
154-
self._expirations.pop(key, None)
155-
self._values[key] = value
156-
157-
def __delitem__(self, key: bytes) -> None:
158-
self._values.pop(key, None)
159-
self._expirations.pop(key, None)
160-
161-
def __len__(self) -> int:
162-
return len(self._values)
163-
164-
def __iter__(self) -> Iterable[bytes]:
165-
return iter(self._values)
166-
167-
def get(self, key: bytes, default: Any = None) -> Any:
168-
return self._values.get(key, default)
169-
170-
def keys(self) -> Iterable[bytes]:
171-
self._expire_keys()
172-
return self._values.keys()
173-
174-
def values(self) -> Iterable[Any]:
175-
return [v for k, v in self.items()]
176-
177-
def items(self) -> Iterable[Tuple[bytes, Any]]:
178-
self._expire_keys()
179-
return self._values.items()
180-
181-
def update(self, values: Dict[bytes, Any]) -> None:
182-
self._expire_keys()
183-
self._values.update(values)
184-
185-
186110
class RedisType:
187111
@classmethod
188112
def decode(cls, *args, **kwargs): # type:ignore

fakeredis/commands_mixins/generic_mixin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414
CommandItem,
1515
SortFloat,
1616
delete_keys,
17-
Hash,
1817
)
1918
from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, Database, SimpleString
20-
from fakeredis.model import ZSet
19+
from fakeredis.model import ZSet, Hash, ExpiringMembersSet
2120

2221

2322
class GenericCommandsMixin:
@@ -224,7 +223,7 @@ def scan(self, cursor, *args):
224223

225224
@command(name="SORT", fixed=(Key(),), repeat=(bytes,))
226225
def sort(self, key, *args):
227-
if key.value is not None and not isinstance(key.value, (set, list, ZSet)):
226+
if key.value is not None and not isinstance(key.value, (ExpiringMembersSet, list, ZSet)):
228227
raise SimpleError(msgs.WRONGTYPE_MSG)
229228
(
230229
asc,

fakeredis/commands_mixins/hash_mixin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from fakeredis import _msgs as msgs
77
from fakeredis._command_args_parsing import extract_args
8-
from fakeredis._commands import command, Key, Hash, Int, Float, CommandItem
8+
from fakeredis._commands import command, Key, Int, Float, CommandItem
99
from fakeredis._helpers import SimpleError, OK, casematch, SimpleString
1010
from fakeredis._helpers import current_time
11+
from fakeredis.model import Hash
1112

1213

1314
class HashCommandsMixin:

fakeredis/commands_mixins/set_mixin.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
from fakeredis import _msgs as msgs
55
from fakeredis._commands import command, Key, Int, CommandItem
66
from fakeredis._helpers import OK, SimpleError, casematch, Database, SimpleString
7+
from fakeredis.model import ExpiringMembersSet
78

89

910
def _calc_setop(op: Callable[..., Any], stop_if_missing: bool, key: CommandItem, *keys: CommandItem) -> Any:
1011
if stop_if_missing and not key.value:
1112
return set()
1213
value = key.value
13-
if not isinstance(value, set):
14+
if not isinstance(value, ExpiringMembersSet):
1415
raise SimpleError(msgs.WRONGTYPE_MSG)
1516
ans = value.copy()
1617
for other in keys:
17-
value = other.value if other.value is not None else set()
18-
if not isinstance(value, set):
18+
value = other.value if other.value is not None else ExpiringMembersSet()
19+
if not isinstance(value, ExpiringMembersSet):
1920
raise SimpleError(msgs.WRONGTYPE_MSG)
2021
if stop_if_missing and not value:
2122
return set()
@@ -48,26 +49,26 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
4849
self.version: Tuple[int]
4950
self._db: Database
5051

51-
@command((Key(set), bytes), (bytes,))
52+
@command((Key(ExpiringMembersSet), bytes), (bytes,))
5253
def sadd(self, key: CommandItem, *members: bytes) -> int:
5354
old_size = len(key.value)
5455
key.value.update(members)
5556
key.updated()
5657
return len(key.value) - old_size
5758

58-
@command((Key(set),))
59+
@command((Key(ExpiringMembersSet),))
5960
def scard(self, key: CommandItem) -> int:
6061
return len(key.value)
6162

62-
@command((Key(set),), (Key(set),))
63+
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
6364
def sdiff(self, *keys: CommandItem) -> Any:
6465
return _setop(lambda a, b: a - b, False, None, *keys)
6566

66-
@command((Key(), Key(set)), (Key(set),))
67+
@command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
6768
def sdiffstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
6869
return _setop(lambda a, b: a - b, False, dst, *keys)
6970

70-
@command((Key(set),), (Key(set),))
71+
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
7172
def sinter(self, *keys: CommandItem) -> Any:
7273
res = _setop(lambda a, b: a & b, True, None, *keys)
7374
return res
@@ -89,23 +90,23 @@ def sintercard(self, numkeys: int, *args: bytes) -> int:
8990
res = _setop(lambda a, b: a & b, False, None, *keys)
9091
return len(res) if limit == 0 else min(limit, len(res))
9192

92-
@command((Key(), Key(set)), (Key(set),))
93+
@command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
9394
def sinterstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
9495
return _setop(lambda a, b: a & b, True, dst, *keys)
9596

96-
@command((Key(set), bytes))
97+
@command((Key(ExpiringMembersSet), bytes))
9798
def sismember(self, key: CommandItem, member: bytes) -> int:
9899
return int(member in key.value)
99100

100-
@command((Key(set), bytes), (bytes,))
101+
@command((Key(ExpiringMembersSet), bytes), (bytes,))
101102
def smismember(self, key: CommandItem, *members: bytes) -> List[int]:
102103
return [self.sismember(key, member) for member in members]
103104

104-
@command((Key(set),))
105+
@command((Key(ExpiringMembersSet),))
105106
def smembers(self, key: CommandItem) -> List[bytes]:
106107
return list(key.value)
107108

108-
@command((Key(set, 0), Key(set), bytes))
109+
@command((Key(ExpiringMembersSet, 0), Key(ExpiringMembersSet), bytes))
109110
def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int:
110111
try:
111112
src.value.remove(member)
@@ -117,7 +118,7 @@ def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int:
117118
dst.updated() # TODO: is it updated if member was already present?
118119
return 1
119120

120-
@command((Key(set),), (Int,))
121+
@command((Key(ExpiringMembersSet),), (Int,))
121122
def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
122123
if count is None:
123124
if not key.value:
@@ -135,7 +136,7 @@ def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, Li
135136
key.updated() # Inside the loop because redis special-cases count=0
136137
return items
137138

138-
@command((Key(set),), (Int,))
139+
@command((Key(ExpiringMembersSet),), (Int,))
139140
def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
140141
if count is None:
141142
if not key.value:
@@ -149,7 +150,7 @@ def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[by
149150
items = list(key.value)
150151
return [random.choice(items) for _ in range(-count)]
151152

152-
@command((Key(set), bytes), (bytes,))
153+
@command((Key(ExpiringMembersSet), bytes), (bytes,))
153154
def srem(self, key: CommandItem, *members: bytes) -> int:
154155
old_size = len(key.value)
155156
for member in members:
@@ -159,15 +160,15 @@ def srem(self, key: CommandItem, *members: bytes) -> int:
159160
key.updated()
160161
return deleted
161162

162-
@command((Key(set), Int), (bytes, bytes))
163+
@command((Key(ExpiringMembersSet), Int), (bytes, bytes))
163164
def sscan(self, key: CommandItem, cursor: int, *args: bytes) -> Any:
164165
return self._scan(key.value, cursor, *args)
165166

166-
@command((Key(set),), (Key(set),))
167+
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
167168
def sunion(self, *keys: CommandItem) -> Any:
168169
return _setop(lambda a, b: a | b, False, None, *keys)
169170

170-
@command((Key(), Key(set)), (Key(set),))
171+
@command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
171172
def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
172173
return _setop(lambda a, b: a | b, False, dst, *keys)
173174

@@ -176,19 +177,19 @@ def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
176177
# approximate and store the results in a string. Instead, it is implemented
177178
# on top of sets.
178179

179-
@command((Key(set),), (bytes,))
180+
@command((Key(ExpiringMembersSet),), (bytes,))
180181
def pfadd(self, key: CommandItem, *elements: bytes) -> int:
181182
result = self.sadd(key, *elements)
182183
# Per the documentation:
183184
# - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise.
184185
return 1 if result > 0 else 0
185186

186-
@command((Key(set),), (Key(set),))
187+
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
187188
def pfcount(self, *keys: CommandItem) -> int:
188189
"""Return the approximated cardinality of the set observed by the HyperLogLog at key(s)."""
189190
return len(self.sunion(*keys))
190191

191-
@command((Key(set), Key(set)), (Key(set),))
192+
@command((Key(ExpiringMembersSet), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
192193
def pfmerge(self, dest: CommandItem, *sources: CommandItem) -> SimpleString:
193194
"""Merge N different HyperLogLogs into a single one."""
194195
self.sunionstore(dest, *sources)

fakeredis/commands_mixins/sortedset_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
null_terminate,
2727
Database,
2828
)
29-
from fakeredis.model import ZSet
29+
from fakeredis.model import ZSet, ExpiringMembersSet
3030

3131
SORTED_SET_METHODS = {
3232
"ZUNIONSTORE": lambda s1, s2: s1 | s2,
@@ -391,7 +391,7 @@ def zscore(self, key, member):
391391

392392
@staticmethod
393393
def _get_zset(value):
394-
if isinstance(value, set):
394+
if isinstance(value, ExpiringMembersSet):
395395
zset = ZSet()
396396
for item in value:
397397
zset[item] = 1.0

fakeredis/model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from ._expiring_members_set import ExpiringMembersSet
2+
from ._hash import Hash
13
from ._stream import XStream, StreamEntryKey, StreamGroup, StreamRangeTest
24
from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS
35
from ._topk import HeavyKeeper
@@ -13,4 +15,6 @@
1315
"TimeSeriesRule",
1416
"AGGREGATORS",
1517
"HeavyKeeper",
18+
"Hash",
19+
"ExpiringMembersSet",
1620
]

0 commit comments

Comments
 (0)