Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion llama-index-core/llama_index/core/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from inspect import signature
from inspect import signature, Parameter
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -40,6 +40,12 @@ def create_schema_from_function(
for param_name in params:
if param_name in ignore_fields:
continue
# Skip *args and **kwargs — they can't be modelled as fixed fields
if params[param_name].kind in (
Parameter.VAR_POSITIONAL,
Parameter.VAR_KEYWORD,
):
continue

param_type = params[param_name].annotation
param_default = params[param_name].default
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,17 @@ def get_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]:
collection_kv_dict = {}
for key, val_str in self._redis_client.hscan_iter(name=collection):
value = dict(json.loads(val_str))
collection_kv_dict[key.decode()] = value
# key may be bytes or str depending on decode_responses
collection_kv_dict[key.decode() if isinstance(key, bytes) else key] = value
return collection_kv_dict

async def aget_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]:
"""Get all values from the store."""
collection_kv_dict = {}
async for key, val_str in self._async_redis_client.hscan_iter(name=collection):
value = dict(json.loads(val_str))
collection_kv_dict[key.decode()] = value
# key may be bytes or str depending on decode_responses
collection_kv_dict[key.decode() if isinstance(key, bytes) else key] = value
return collection_kv_dict

def delete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,88 @@
from llama_index.core.storage.kvstore.types import BaseKVStore
from unittest.mock import MagicMock, AsyncMock, patch

import pytest

from llama_index.storage.kvstore.redis import RedisKVStore


def test_class():
names_of_base_classes = [b.__name__ for b in RedisKVStore.__mro__]
assert BaseKVStore.__name__ in names_of_base_classes
def test_get_all_decode_responses_false():
"""Test get_all when Redis returns bytes keys (decode_responses=False)."""
mock_redis = MagicMock()
# Simulate hscan_iter returning bytes keys (default behavior)
mock_redis.hscan_iter.return_value = iter([
(b"key1", '{"name": "alice"}'),
(b"key2", '{"name": "bob"}'),
])

store = RedisKVStore(redis_client=mock_redis)
result = store.get_all()

assert result == {"key1": {"name": "alice"}, "key2": {"name": "bob"}}
mock_redis.hscan_iter.assert_called_once()


def test_get_all_decode_responses_true():
"""Test get_all when Redis returns string keys (decode_responses=True)."""
mock_redis = MagicMock()
# Simulate hscan_iter returning string keys (decode_responses=True)
mock_redis.hscan_iter.return_value = iter([
("key1", '{"name": "alice"}'),
("key2", '{"name": "bob"}'),
])

store = RedisKVStore(redis_client=mock_redis)
result = store.get_all()

assert result == {"key1": {"name": "alice"}, "key2": {"name": "bob"}}
mock_redis.hscan_iter.assert_called_once()


def test_get_all_mixed_keys():
"""Test get_all with mixed bytes/string keys (edge case sanity check)."""
mock_redis = MagicMock()
mock_redis.hscan_iter.return_value = iter([
(b"bytes_key", '{"type": "bytes"}'),
("str_key", '{"type": "str"}'),
])

store = RedisKVStore(redis_client=mock_redis)
result = store.get_all()

assert result == {
"bytes_key": {"type": "bytes"},
"str_key": {"type": "str"},
}


@pytest.mark.asyncio
async def test_aget_all_decode_responses_false():
"""Test aget_all when async Redis returns bytes keys (decode_responses=False)."""
mock_async_redis = AsyncMock()
# Simulate async hscan_iter returning bytes keys
async def async_iter():
for item in [(b"akey1", '{"id": 1}'), (b"akey2", '{"id": 2}')]:
yield item

mock_async_redis.hscan_iter.return_value = async_iter()

store = RedisKVStore(async_redis_client=mock_async_redis)
result = await store.aget_all()

assert result == {"akey1": {"id": 1}, "akey2": {"id": 2}}


@pytest.mark.asyncio
async def test_aget_all_decode_responses_true():
"""Test aget_all when async Redis returns string keys (decode_responses=True)."""
mock_async_redis = AsyncMock()
# Simulate async hscan_iter returning string keys
async def async_iter():
for item in [("akey1", '{"id": 1}'), ("akey2", '{"id": 2}')]:
yield item

mock_async_redis.hscan_iter.return_value = async_iter()

store = RedisKVStore(async_redis_client=mock_async_redis)
result = await store.aget_all()

assert result == {"akey1": {"id": 1}, "akey2": {"id": 2}}