Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ]
pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, we definitely needed to test against 3.13

redisstack: [ "latest" ]
fail-fast: false
services:
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,7 @@ tests_sync/
# spelling cruft
*.dic

.idea
.idea

# version files
.tool-versions
18 changes: 14 additions & 4 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,14 @@ def outer_type_or_annotation(field: FieldInfo):
return field.annotation.__args__[0] # type: ignore


def _is_numeric_type(type_: Type[Any]) -> bool:
args = get_args(type_)
try:
return any(issubclass(args[0], t) for t in NUMERIC_TYPES)
except TypeError:
return False


def should_index_field(field_info: Union[FieldInfo, PydanticFieldInfo]) -> bool:
# for vector, full text search, and sortable fields, we always have to index
# We could require the user to set index=True, but that would be a breaking change
Expand Down Expand Up @@ -2004,9 +2012,7 @@ def schema_for_type(
field_info, "vector_options", None
)
try:
is_vector = vector_options and any(
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
)
is_vector = vector_options and _is_numeric_type(typ)
except IndexError:
raise RedisModelError(
f"Vector field '{name}' must be annotated as a container type"
Expand Down Expand Up @@ -2104,7 +2110,11 @@ def schema_for_type(
# a proper type, we can pull the type information from the origin of the first argument.
if not isinstance(typ, type):
type_args = typing_get_args(field_info.annotation)
typ = type_args[0].__origin__
typ = (
getattr(type_args[0], "__origin__", type_args[0])
if type_args
else typ
)

# TODO: GEO field
if is_vector and vector_options:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redis-om"
version = "1.0.1-beta"
version = "1.0.2-beta"
description = "Object mappings, and more, for Redis."
authors = ["Redis OSS <[email protected]>"]
maintainers = ["Redis OSS <[email protected]>"]
Expand All @@ -22,6 +22,7 @@ classifiers = [
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: 3.13',
'Programming Language :: Python',
]
include=[
Expand Down
6 changes: 3 additions & 3 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m):
async def test_pagination_queries(members, m):
member1, member2, member3 = members

actual = await m.Member.find(m.Member.last_name == "Brookins").page()
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page()

assert actual == [member1, member2]

actual = await m.Member.find().page(1, 1)
actual = await m.Member.find().sort_by("id").page(1, 1)

assert actual == [member2]

actual = await m.Member.find().page(0, 1)
actual = await m.Member.find().sort_by("id").page(0, 1)

assert actual == [member1]

Expand Down
45 changes: 43 additions & 2 deletions tests/test_knn_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,24 @@ class Meta:

class Member(BaseJsonModel, index=True):
name: str
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
embeddings: list[float] = Field([], vector_options=vector_field_options)
embeddings_score: Optional[float] = None

await Migrator().run()

return Member


@pytest_asyncio.fixture
async def n(key_prefix, redis):
class BaseJsonModel(JsonModel, abc.ABC):
class Meta:
global_key_prefix = key_prefix
database = redis

class Member(BaseJsonModel, index=True):
name: str
nested: list[list[float]] = Field([], vector_options=vector_field_options)
embeddings_score: Optional[float] = None

await Migrator().run()
Expand All @@ -45,7 +62,7 @@ def to_bytes(vectors: list[float]) -> bytes:
async def test_vector_field(m: Type[JsonModel]):
# Create a new instance of the Member model
vectors = [0.3 for _ in range(DIMENSIONS)]
member = m(name="seth", embeddings=[vectors])
member = m(name="seth", embeddings=vectors)

# Save the member to Redis
await member.save()
Expand All @@ -63,3 +80,27 @@ async def test_vector_field(m: Type[JsonModel]):

assert len(members) == 1
assert members[0].embeddings_score is not None


@py_test_mark_asyncio
async def test_nested_vector_field(n: Type[JsonModel]):
# Create a new instance of the Member model
vectors = [0.3 for _ in range(DIMENSIONS)]
member = n(name="seth", nested=[vectors])

# Save the member to Redis
await member.save()

knn = KNNExpression(
k=1,
vector_field=n.nested,
score_field=n.embeddings_score,
reference_vector=to_bytes(vectors),
)

query = n.find(knn=knn)

members = await query.all()

assert len(members) == 1
assert members[0].embeddings_score is not None