Skip to content

Commit e11150a

Browse files
authored
Merge pull request #743 from rootart/issue/597
Add support for the set functions from issue #597
2 parents ce47b30 + f34935c commit e11150a

File tree

5 files changed

+657
-3
lines changed

5 files changed

+657
-3
lines changed

changelog.d/730.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support for sets and support basic operations, sadd, scard, sdiff, sdiffstore, sinter, sinterstore, smismember, sismember, smembers, smove, spop, srandmember, srem, sscan, sscan_iter, sunion, sunionstore

django_redis/cache.py

+68
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,74 @@ def close(self, **kwargs):
185185
def touch(self, *args, **kwargs):
186186
return self.client.touch(*args, **kwargs)
187187

188+
@omit_exception
189+
def sadd(self, *args, **kwargs):
190+
return self.client.sadd(*args, **kwargs)
191+
192+
@omit_exception
193+
def scard(self, *args, **kwargs):
194+
return self.client.scard(*args, **kwargs)
195+
196+
@omit_exception
197+
def sdiff(self, *args, **kwargs):
198+
return self.client.sdiff(*args, **kwargs)
199+
200+
@omit_exception
201+
def sdiffstore(self, *args, **kwargs):
202+
return self.client.sdiffstore(*args, **kwargs)
203+
204+
@omit_exception
205+
def sinter(self, *args, **kwargs):
206+
return self.client.sinter(*args, **kwargs)
207+
208+
@omit_exception
209+
def sinterstore(self, *args, **kwargs):
210+
return self.client.sinterstore(*args, **kwargs)
211+
212+
@omit_exception
213+
def sismember(self, *args, **kwargs):
214+
return self.client.sismember(*args, **kwargs)
215+
216+
@omit_exception
217+
def smembers(self, *args, **kwargs):
218+
return self.client.smembers(*args, **kwargs)
219+
220+
@omit_exception
221+
def smove(self, *args, **kwargs):
222+
return self.client.smove(*args, **kwargs)
223+
224+
@omit_exception
225+
def spop(self, *args, **kwargs):
226+
return self.client.spop(*args, **kwargs)
227+
228+
@omit_exception
229+
def srandmember(self, *args, **kwargs):
230+
return self.client.srandmember(*args, **kwargs)
231+
232+
@omit_exception
233+
def srem(self, *args, **kwargs):
234+
return self.client.srem(*args, **kwargs)
235+
236+
@omit_exception
237+
def sscan(self, *args, **kwargs):
238+
return self.client.sscan(*args, **kwargs)
239+
240+
@omit_exception
241+
def sscan_iter(self, *args, **kwargs):
242+
return self.client.sscan_iter(*args, **kwargs)
243+
244+
@omit_exception
245+
def smismember(self, *args, **kwargs):
246+
return self.client.smismember(*args, **kwargs)
247+
248+
@omit_exception
249+
def sunion(self, *args, **kwargs):
250+
return self.client.sunion(*args, **kwargs)
251+
252+
@omit_exception
253+
def sunionstore(self, *args, **kwargs):
254+
return self.client.sunionstore(*args, **kwargs)
255+
188256
@omit_exception
189257
def hset(self, *args, **kwargs):
190258
return self.client.hset(*args, **kwargs)

django_redis/client/default.py

+283-2
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,26 @@
33
import socket
44
from collections import OrderedDict
55
from contextlib import suppress
6-
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
6+
from typing import (
7+
Any,
8+
Dict,
9+
Iterable,
10+
Iterator,
11+
List,
12+
Optional,
13+
Set,
14+
Tuple,
15+
Union,
16+
cast,
17+
)
718

819
from django.conf import settings
920
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func
1021
from django.core.exceptions import ImproperlyConfigured
1122
from django.utils.module_loading import import_string
1223
from redis import Redis
1324
from redis.exceptions import ConnectionError, ResponseError, TimeoutError
14-
from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT
25+
from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT, PatternT
1526

1627
from django_redis import pool
1728
from django_redis.exceptions import CompressorError, ConnectionInterrupted
@@ -66,6 +77,14 @@ def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None:
6677
def __contains__(self, key: KeyT) -> bool:
6778
return self.has_key(key)
6879

80+
def _has_compression_enabled(self) -> bool:
81+
return (
82+
self._options.get(
83+
"COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor"
84+
)
85+
!= "django_redis.compressors.identity.IdentityCompressor"
86+
)
87+
6988
def get_next_client_index(
7089
self, write: bool = True, tried: Optional[List[int]] = None
7190
) -> int:
@@ -498,6 +517,17 @@ def encode(self, value: EncodableT) -> Union[bytes, int]:
498517

499518
return value
500519

520+
def _decode_iterable_result(
521+
self, result: Any, covert_to_set: bool = True
522+
) -> Union[List[Any], None, Any]:
523+
if result is None:
524+
return None
525+
if isinstance(result, list):
526+
if covert_to_set:
527+
return {self.decode(value) for value in result}
528+
return [self.decode(value) for value in result]
529+
return self.decode(result)
530+
501531
def get_many(
502532
self,
503533
keys: Iterable[KeyT],
@@ -778,6 +808,257 @@ def make_pattern(
778808

779809
return CacheKey(self._backend.key_func(pattern, prefix, version_str))
780810

811+
def sadd(
812+
self,
813+
key: KeyT,
814+
*values: Any,
815+
version: Optional[int] = None,
816+
client: Optional[Redis] = None,
817+
) -> int:
818+
if client is None:
819+
client = self.get_client(write=True)
820+
821+
key = self.make_key(key, version=version)
822+
encoded_values = [self.encode(value) for value in values]
823+
return int(client.sadd(key, *encoded_values))
824+
825+
def scard(
826+
self,
827+
key: KeyT,
828+
version: Optional[int] = None,
829+
client: Optional[Redis] = None,
830+
) -> int:
831+
if client is None:
832+
client = self.get_client(write=False)
833+
834+
key = self.make_key(key, version=version)
835+
return int(client.scard(key))
836+
837+
def sdiff(
838+
self,
839+
*keys: KeyT,
840+
version: Optional[int] = None,
841+
client: Optional[Redis] = None,
842+
) -> Set[Any]:
843+
if client is None:
844+
client = self.get_client(write=False)
845+
846+
nkeys = [self.make_key(key, version=version) for key in keys]
847+
return {self.decode(value) for value in client.sdiff(*nkeys)}
848+
849+
def sdiffstore(
850+
self,
851+
dest: KeyT,
852+
*keys: KeyT,
853+
version_dest: Optional[int] = None,
854+
version_keys: Optional[int] = None,
855+
client: Optional[Redis] = None,
856+
) -> int:
857+
if client is None:
858+
client = self.get_client(write=True)
859+
860+
dest = self.make_key(dest, version=version_dest)
861+
nkeys = [self.make_key(key, version=version_keys) for key in keys]
862+
return int(client.sdiffstore(dest, *nkeys))
863+
864+
def sinter(
865+
self,
866+
*keys: KeyT,
867+
version: Optional[int] = None,
868+
client: Optional[Redis] = None,
869+
) -> Set[Any]:
870+
if client is None:
871+
client = self.get_client(write=False)
872+
873+
nkeys = [self.make_key(key, version=version) for key in keys]
874+
return {self.decode(value) for value in client.sinter(*nkeys)}
875+
876+
def sinterstore(
877+
self,
878+
dest: KeyT,
879+
*keys: KeyT,
880+
version: Optional[int] = None,
881+
client: Optional[Redis] = None,
882+
) -> int:
883+
if client is None:
884+
client = self.get_client(write=True)
885+
886+
dest = self.make_key(dest, version=version)
887+
nkeys = [self.make_key(key, version=version) for key in keys]
888+
return int(client.sinterstore(dest, *nkeys))
889+
890+
def smismember(
891+
self,
892+
key: KeyT,
893+
*members,
894+
version: Optional[int] = None,
895+
client: Optional[Redis] = None,
896+
) -> List[bool]:
897+
if client is None:
898+
client = self.get_client(write=False)
899+
900+
key = self.make_key(key, version=version)
901+
encoded_members = [self.encode(member) for member in members]
902+
903+
return [bool(value) for value in client.smismember(key, *encoded_members)]
904+
905+
def sismember(
906+
self,
907+
key: KeyT,
908+
member: Any,
909+
version: Optional[int] = None,
910+
client: Optional[Redis] = None,
911+
) -> bool:
912+
if client is None:
913+
client = self.get_client(write=False)
914+
915+
key = self.make_key(key, version=version)
916+
member = self.encode(member)
917+
return bool(client.sismember(key, member))
918+
919+
def smembers(
920+
self,
921+
key: KeyT,
922+
version: Optional[int] = None,
923+
client: Optional[Redis] = None,
924+
) -> Set[Any]:
925+
if client is None:
926+
client = self.get_client(write=False)
927+
928+
key = self.make_key(key, version=version)
929+
return {self.decode(value) for value in client.smembers(key)}
930+
931+
def smove(
932+
self,
933+
source: KeyT,
934+
destination: KeyT,
935+
member: Any,
936+
version: Optional[int] = None,
937+
client: Optional[Redis] = None,
938+
) -> bool:
939+
if client is None:
940+
client = self.get_client(write=True)
941+
942+
source = self.make_key(source, version=version)
943+
destination = self.make_key(destination)
944+
member = self.encode(member)
945+
return bool(client.smove(source, destination, member))
946+
947+
def spop(
948+
self,
949+
key: KeyT,
950+
count: Optional[int] = None,
951+
version: Optional[int] = None,
952+
client: Optional[Redis] = None,
953+
) -> Union[Set, Any]:
954+
if client is None:
955+
client = self.get_client(write=True)
956+
957+
nkey = self.make_key(key, version=version)
958+
result = client.spop(nkey, count)
959+
return self._decode_iterable_result(result)
960+
961+
def srandmember(
962+
self,
963+
key: KeyT,
964+
count: Optional[int] = None,
965+
version: Optional[int] = None,
966+
client: Optional[Redis] = None,
967+
) -> Union[List, Any]:
968+
if client is None:
969+
client = self.get_client(write=False)
970+
971+
key = self.make_key(key, version=version)
972+
result = client.srandmember(key, count)
973+
return self._decode_iterable_result(result, covert_to_set=False)
974+
975+
def srem(
976+
self,
977+
key: KeyT,
978+
*members: EncodableT,
979+
version: Optional[int] = None,
980+
client: Optional[Redis] = None,
981+
) -> int:
982+
if client is None:
983+
client = self.get_client(write=True)
984+
985+
key = self.make_key(key, version=version)
986+
nmembers = [self.encode(member) for member in members]
987+
return int(client.srem(key, *nmembers))
988+
989+
def sscan(
990+
self,
991+
key: KeyT,
992+
match: Optional[str] = None,
993+
count: Optional[int] = 10,
994+
version: Optional[int] = None,
995+
client: Optional[Redis] = None,
996+
) -> Set[Any]:
997+
if self._has_compression_enabled() and match:
998+
err_msg = "Using match with compression is not supported."
999+
raise ValueError(err_msg)
1000+
1001+
if client is None:
1002+
client = self.get_client(write=False)
1003+
1004+
key = self.make_key(key, version=version)
1005+
1006+
cursor, result = client.sscan(
1007+
key,
1008+
match=cast(PatternT, self.encode(match)) if match else None,
1009+
count=count,
1010+
)
1011+
return {self.decode(value) for value in result}
1012+
1013+
def sscan_iter(
1014+
self,
1015+
key: KeyT,
1016+
match: Optional[str] = None,
1017+
count: Optional[int] = 10,
1018+
version: Optional[int] = None,
1019+
client: Optional[Redis] = None,
1020+
) -> Iterator[Any]:
1021+
if self._has_compression_enabled() and match:
1022+
err_msg = "Using match with compression is not supported."
1023+
raise ValueError(err_msg)
1024+
1025+
if client is None:
1026+
client = self.get_client(write=False)
1027+
1028+
key = self.make_key(key, version=version)
1029+
for value in client.sscan_iter(
1030+
key,
1031+
match=cast(PatternT, self.encode(match)) if match else None,
1032+
count=count,
1033+
):
1034+
yield self.decode(value)
1035+
1036+
def sunion(
1037+
self,
1038+
*keys: KeyT,
1039+
version: Optional[int] = None,
1040+
client: Optional[Redis] = None,
1041+
) -> Set[Any]:
1042+
if client is None:
1043+
client = self.get_client(write=False)
1044+
1045+
nkeys = [self.make_key(key, version=version) for key in keys]
1046+
return {self.decode(value) for value in client.sunion(*nkeys)}
1047+
1048+
def sunionstore(
1049+
self,
1050+
destination: Any,
1051+
*keys: KeyT,
1052+
version: Optional[int] = None,
1053+
client: Optional[Redis] = None,
1054+
) -> int:
1055+
if client is None:
1056+
client = self.get_client(write=True)
1057+
1058+
destination = self.make_key(destination, version=version)
1059+
encoded_keys = [self.make_key(key, version=version) for key in keys]
1060+
return int(client.sunionstore(destination, *encoded_keys))
1061+
7811062
def close(self) -> None:
7821063
close_flag = self._options.get(
7831064
"CLOSE_CONNECTION",

0 commit comments

Comments
 (0)