Skip to content

Commit 29cb397

Browse files
authored
Validate passed-in Redis clients (redis#296)
Prior to RedisVL 0.4.0, we validated passed-in Redis clients when the user called `set_client()`. This PR reintroduces similar behavior by validating all clients, whether we created them or not, on first access through the lazy-client mechanism. Closes RAAE-694.
1 parent 38c0a60 commit 29cb397

File tree

4 files changed

+70
-7
lines changed

4 files changed

+70
-7
lines changed

redisvl/index/index.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def __init__(
282282
self._connection_kwargs = connection_kwargs or {}
283283
self._lock = threading.Lock()
284284

285+
self._validated_client = False
285286
self._owns_redis_client = redis_client is None
286287
if self._owns_redis_client:
287288
weakref.finalize(self, self.disconnect)
@@ -361,6 +362,12 @@ def _redis_client(self) -> Optional[redis.Redis]:
361362
redis_url=self._redis_url,
362363
**self._connection_kwargs,
363364
)
365+
if not self._validated_client:
366+
RedisConnectionFactory.validate_sync_redis(
367+
self.__redis_client,
368+
self._lib_name,
369+
)
370+
self._validated_client = True
364371
return self.__redis_client
365372

366373
@deprecated_function("connect", "Pass connection parameters in __init__.")
@@ -858,6 +865,7 @@ def __init__(
858865
self._connection_kwargs = connection_kwargs or {}
859866
self._lock = asyncio.Lock()
860867

868+
self._validated_client = False
861869
self._owns_redis_client = redis_client is None
862870
if self._owns_redis_client:
863871
weakref.finalize(self, sync_wrapper(self.disconnect))
@@ -954,9 +962,12 @@ async def _get_client(self) -> aredis.Redis:
954962
self._redis_client = (
955963
await RedisConnectionFactory._get_aredis_connection(**kwargs)
956964
)
965+
if not self._validated_client:
957966
await RedisConnectionFactory.validate_async_redis(
958-
self._redis_client, self._lib_name
967+
self._redis_client,
968+
self._lib_name,
959969
)
970+
self._validated_client = True
960971
return self._redis_client
961972

962973
async def _validate_client(

redisvl/redis/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def validate_modules(
159159
required_modules: List of required modules.
160160
161161
Raises:
162-
ValueError: If required Redis modules are not installed.
162+
RedisModuleVersionError: If required Redis modules are not installed.
163163
"""
164164
required_modules = required_modules or DEFAULT_REQUIRED_MODULES
165165

tests/integration/test_async_search_index.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import warnings
2+
from unittest import mock
23

34
import pytest
45
from redis import Redis as SyncRedis
5-
from redis.asyncio import Redis
6+
from redis.asyncio import Redis as AsyncRedis
67

7-
from redisvl.exceptions import RedisSearchError
8+
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
89
from redisvl.index import AsyncSearchIndex
910
from redisvl.query import VectorQuery
1011
from redisvl.redis.utils import convert_bytes
@@ -172,12 +173,12 @@ async def test_search_index_set_client(client, redis_url, index_schema):
172173
with warnings.catch_warnings():
173174
warnings.filterwarnings("ignore", category=DeprecationWarning)
174175
await async_index.create(overwrite=True, drop=True)
175-
assert isinstance(async_index.client, Redis)
176+
assert isinstance(async_index.client, AsyncRedis)
176177

177178
# Tests deprecated sync -> async conversion behavior
178179
assert isinstance(client, SyncRedis)
179180
await async_index.set_client(client)
180-
assert isinstance(async_index.client, Redis)
181+
assert isinstance(async_index.client, AsyncRedis)
181182

182183
await async_index.disconnect()
183184
assert async_index.client is None
@@ -410,3 +411,28 @@ async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis
410411
await async_index.create(overwrite=True, drop=True)
411412
await async_index.disconnect()
412413
assert async_index._redis_client is None
414+
415+
416+
@pytest.mark.asyncio
417+
async def test_async_search_index_validates_redis_modules(redis_url):
418+
"""
419+
A regression test for RAAE-694: we should validate that a passed-in
420+
Redis client has the correct modules installed.
421+
"""
422+
client = AsyncRedis.from_url(redis_url)
423+
with mock.patch(
424+
"redisvl.index.index.RedisConnectionFactory.validate_async_redis"
425+
) as mock_validate_async_redis:
426+
mock_validate_async_redis.side_effect = RedisModuleVersionError(
427+
"Required modules not installed"
428+
)
429+
with pytest.raises(RedisModuleVersionError):
430+
index = AsyncSearchIndex(
431+
schema=IndexSchema.from_dict(
432+
{"index": {"name": "my_index"}, "fields": fields}
433+
),
434+
redis_client=client,
435+
)
436+
await index.create(overwrite=True, drop=True)
437+
438+
mock_validate_async_redis.assert_called_once()

tests/integration/test_search_index.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import warnings
2+
from unittest import mock
23

34
import pytest
5+
from redis import Redis
46

5-
from redisvl.exceptions import RedisSearchError
7+
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
68
from redisvl.index import SearchIndex
79
from redisvl.query import VectorQuery
810
from redisvl.redis.utils import convert_bytes
@@ -363,3 +365,27 @@ def test_search_index_that_owns_client_disconnect(index_schema, redis_url):
363365
index.create(overwrite=True, drop=True)
364366
index.disconnect()
365367
assert index.client is None
368+
369+
370+
def test_search_index_validates_redis_modules(redis_url):
371+
"""
372+
A regression test for RAAE-694: we should validate that a passed-in
373+
Redis client has the correct modules installed.
374+
"""
375+
client = Redis.from_url(redis_url)
376+
with mock.patch(
377+
"redisvl.index.index.RedisConnectionFactory.validate_sync_redis"
378+
) as mock_validate_sync_redis:
379+
mock_validate_sync_redis.side_effect = RedisModuleVersionError(
380+
"Required modules not installed"
381+
)
382+
with pytest.raises(RedisModuleVersionError):
383+
index = SearchIndex(
384+
schema=IndexSchema.from_dict(
385+
{"index": {"name": "my_index"}, "fields": fields}
386+
),
387+
redis_client=client,
388+
)
389+
index.create(overwrite=True, drop=True)
390+
391+
mock_validate_sync_redis.assert_called_once()

0 commit comments

Comments
 (0)