Skip to content

Commit 372e5ec

Browse files
committed
Fix DeprecationWarning caused by mixing valkey and redis classes
The `DeprecationWarning` is: ``` DeprecationWarning: Call to 'get_connection' function with deprecated usage of input argument/s '['command_name']'. (Use get_connection() without args instead) -- Deprecated since version 5.3.0. ```
1 parent 9bca96a commit 372e5ec

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

fakeredis/_connection.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,27 @@
1111
from ._typing import Self, lib_version, RaiseErrorTypes, VersionType, ServerType
1212

1313

14-
class FakeConnection(FakeBaseConnectionMixin, redis.Connection):
14+
class FakeConnection(FakeBaseConnectionMixin):
15+
manifested_connection_classes = {}
16+
17+
def __new__(cls, *args: Any, **kwargs: Any) -> "FakeConnection":
18+
"""Inherit dynamically from the correct Connection class.
19+
20+
Currently, only Valkey is a special case.
21+
"""
22+
23+
connection_class = kwargs.get("connection_class", redis.Connection)
24+
25+
if connection_class not in cls.manifested_connection_classes:
26+
module_name, _, _ = connection_class.__module__.partition(".")
27+
new_class_name = f"Fake{module_name.title()}Connection"
28+
base_class = type(new_class_name, (cls, connection_class), {})
29+
cls.manifested_connection_classes[connection_class] = base_class
30+
else:
31+
base_class = cls.manifested_connection_classes[connection_class]
32+
33+
return object.__new__(base_class)
34+
1535
def __init__(*args: Any, **kwargs: Any) -> None:
1636
FakeBaseConnectionMixin.__init__(*args, **kwargs)
1737

@@ -148,7 +168,14 @@ def __init__(
148168
"client_class": client_class,
149169
}
150170
connection_kwargs.update({arg: kwds[arg] for arg in conn_pool_args if arg in kwds})
151-
kwds["connection_pool"] = redis.connection.ConnectionPool(**connection_kwargs)
171+
if server_type == "valkey":
172+
import valkey.connection
173+
174+
valkey_pool = valkey.connection.ConnectionPool(**connection_kwargs)
175+
kwds["connection_pool"] = valkey_pool
176+
else:
177+
redis_pool = redis.connection.ConnectionPool(**connection_kwargs)
178+
kwds["connection_pool"] = redis_pool
152179
kwds.pop("server", None)
153180
kwds.pop("connected", None)
154181
kwds.pop("version", None)

fakeredis/_valkey.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def _validate_server_type(args_dict: Dict[str, Any]) -> None:
1717
class FakeValkey(FakeRedisMixin, valkey.Valkey):
1818
def __init__(self, *args: Any, **kwargs: Any) -> None:
1919
_validate_server_type(kwargs)
20+
kwargs["connection_class"] = valkey.Connection
2021
super().__init__(*args, **kwargs)
2122

2223
@classmethod
@@ -27,6 +28,7 @@ def from_url(cls, *args: Any, **kwargs: Any) -> Self:
2728
class FakeStrictValkey(FakeRedisMixin, valkey.StrictValkey):
2829
def __init__(self, *args: Any, **kwargs: Any) -> None:
2930
_validate_server_type(kwargs)
31+
kwargs["connection_class"] = valkey.Connection
3032
super(FakeStrictValkey, self).__init__(*args, **kwargs)
3133

3234
@classmethod
@@ -38,6 +40,7 @@ class FakeAsyncValkey(FakeAsyncRedisMixin, valkey.asyncio.Valkey):
3840
def __init__(self, *args: Any, **kwargs: Any) -> None:
3941
kwargs.setdefault("client_class", valkey.asyncio.Valkey)
4042
_validate_server_type(kwargs)
43+
kwargs["connection_class"] = valkey.asyncio.connection.Connection
4144
super(FakeAsyncValkey, self).__init__(*args, **kwargs)
4245

4346
@classmethod

fakeredis/aioredis.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,27 @@ def writelines(self, data: Iterable[Any]) -> None:
111111
self._socket.sendall(chunk) # type:ignore
112112

113113

114-
class FakeConnection(FakeBaseConnectionMixin, redis_async.Connection):
114+
class FakeConnection(FakeBaseConnectionMixin):
115+
manifested_connection_classes = {}
116+
117+
def __new__(cls, *args: Any, **kwargs: Any) -> "FakeConnection":
118+
"""Inherit dynamically from the correct Connection class.
119+
120+
Currently, only Valkey is a special case.
121+
"""
122+
123+
connection_class = kwargs.get("connection_class", redis_async.Connection)
124+
125+
if connection_class not in cls.manifested_connection_classes:
126+
module_name, _, _ = connection_class.__module__.partition(".")
127+
new_class_name = f"Fake{module_name.title()}Connection"
128+
base_class = type(new_class_name, (cls, connection_class), {})
129+
cls.manifested_connection_classes[connection_class] = base_class
130+
else:
131+
base_class = cls.manifested_connection_classes[connection_class]
132+
133+
return object.__new__(base_class)
134+
115135
async def _connect(self) -> None:
116136
if not self._server.connected:
117137
raise redis_async.ConnectionError(msgs.CONNECTION_ERROR_MSG)
@@ -247,7 +267,14 @@ def __init__(
247267
"client_class": client_class,
248268
}
249269
connection_kwargs.update({arg: kwds[arg] for arg in conn_pool_args if arg in kwds})
250-
kwds["connection_pool"] = redis_async.connection.ConnectionPool(**connection_kwargs) # type: ignore
270+
if server_type == "valkey":
271+
import valkey.asyncio.connection
272+
273+
v_pool = valkey.asyncio.connection.ConnectionPool(**connection_kwargs) # type: ignore
274+
kwds["connection_pool"] = v_pool
275+
else:
276+
r_pool = redis_async.connection.ConnectionPool(**connection_kwargs) # type: ignore
277+
kwds["connection_pool"] = r_pool
251278
kwds.pop("server", None)
252279
kwds.pop("connected", None)
253280
kwds.pop("version", None)

0 commit comments

Comments
 (0)