|
| 1 | +from typing import Any |
| 2 | + |
| 3 | +import pytest |
| 4 | +import strawberry |
| 5 | +from sqlalchemy import Column, ForeignKey, Integer, String |
| 6 | +from sqlalchemy.ext.asyncio.engine import AsyncEngine |
| 7 | +from sqlalchemy.orm import relationship |
| 8 | +from strawberry.relay.utils import to_base64 |
| 9 | +from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper |
| 10 | +from strawberry_sqlalchemy_mapper.loader import StrawberrySQLAlchemyLoader |
| 11 | + |
| 12 | + |
| 13 | +@pytest.fixture |
| 14 | +def user_and_group_tables(base: Any): |
| 15 | + class User(base): |
| 16 | + __tablename__ = "user" |
| 17 | + id = Column(Integer, autoincrement=True, primary_key=True) |
| 18 | + name = Column(String(50), nullable=False) |
| 19 | + group_id = Column(Integer, ForeignKey("group.id")) |
| 20 | + group = relationship("Group", back_populates="users") |
| 21 | + |
| 22 | + class Group(base): |
| 23 | + __tablename__ = "group" |
| 24 | + id = Column(Integer, autoincrement=True, primary_key=True) |
| 25 | + name = Column(String, nullable=False) |
| 26 | + users = relationship("User", back_populates="group") |
| 27 | + |
| 28 | + return User, Group |
| 29 | + |
| 30 | + |
| 31 | +@pytest.mark.asyncio |
| 32 | +async def test_query_auto_generated_connection( |
| 33 | + base: Any, |
| 34 | + async_engine: AsyncEngine, |
| 35 | + async_sessionmaker, |
| 36 | + user_and_group_tables, |
| 37 | +): |
| 38 | + user_table, group_table = user_and_group_tables |
| 39 | + |
| 40 | + async with async_engine.begin() as conn: |
| 41 | + await conn.run_sync(base.metadata.create_all) |
| 42 | + mapper = StrawberrySQLAlchemyMapper() |
| 43 | + |
| 44 | + global User, Group |
| 45 | + try: |
| 46 | + |
| 47 | + @mapper.type(user_table) |
| 48 | + class User: |
| 49 | + ... |
| 50 | + |
| 51 | + @mapper.type(group_table) |
| 52 | + class Group: |
| 53 | + ... |
| 54 | + |
| 55 | + @strawberry.type |
| 56 | + class Query: |
| 57 | + @strawberry.field |
| 58 | + async def group(self, id: strawberry.ID) -> Group: |
| 59 | + session = async_sessionmaker() |
| 60 | + return await session.get(group_table, int(id)) |
| 61 | + |
| 62 | + schema = strawberry.Schema(query=Query) |
| 63 | + |
| 64 | + query = """\ |
| 65 | + query GetGroup ($id: ID!) { |
| 66 | + group(id: $id) { |
| 67 | + id |
| 68 | + name |
| 69 | + users { |
| 70 | + pageInfo { |
| 71 | + hasNextPage |
| 72 | + hasPreviousPage |
| 73 | + startCursor |
| 74 | + endCursor |
| 75 | + } |
| 76 | + edges { |
| 77 | + node { |
| 78 | + id |
| 79 | + name |
| 80 | + } |
| 81 | + } |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + """ |
| 86 | + |
| 87 | + async with async_sessionmaker(expire_on_commit=False) as session: |
| 88 | + group = group_table(name="Foo Bar") |
| 89 | + user1 = user_table(name="User 1", group=group) |
| 90 | + user2 = user_table(name="User 2", group=group) |
| 91 | + user3 = user_table(name="User 3", group=group) |
| 92 | + session.add_all([group, user1, user2, user3]) |
| 93 | + await session.commit() |
| 94 | + |
| 95 | + result = await schema.execute( |
| 96 | + query, |
| 97 | + variable_values={"id": group.id}, |
| 98 | + context_value={ |
| 99 | + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( |
| 100 | + async_bind_factory=async_sessionmaker |
| 101 | + ) |
| 102 | + }, |
| 103 | + ) |
| 104 | + assert result.errors is None |
| 105 | + assert result.data == { |
| 106 | + "group": { |
| 107 | + "id": group.id, |
| 108 | + "name": "Foo Bar", |
| 109 | + "users": { |
| 110 | + "pageInfo": { |
| 111 | + "hasNextPage": False, |
| 112 | + "hasPreviousPage": False, |
| 113 | + "startCursor": to_base64("arrayconnection", "0"), |
| 114 | + "endCursor": to_base64("arrayconnection", "2"), |
| 115 | + }, |
| 116 | + "edges": [ |
| 117 | + {"node": {"id": user1.id, "name": "User 1"}}, |
| 118 | + {"node": {"id": user2.id, "name": "User 2"}}, |
| 119 | + {"node": {"id": user3.id, "name": "User 3"}}, |
| 120 | + ], |
| 121 | + }, |
| 122 | + }, |
| 123 | + } |
| 124 | + finally: |
| 125 | + del User, Group |
0 commit comments