diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 00000000..30ed9526 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,26 @@ +name: 🚀 Deploy to PyPI + +on: + push: + tags: + - '*' + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheel and source tarball + run: | + pip install wheel + python setup.py sdist bdist_wheel + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@v1.1.0 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..89f44467 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,19 @@ +name: Deploy Docs + +# Runs on pushes targeting the default branch +on: + push: + branches: [master] + +jobs: + pages: + runs-on: ubuntu-22.04 + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + permissions: + pages: write + id-token: write + steps: + - id: deployment + uses: sphinx-notes/pages@v3 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..099e9177 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,28 @@ +name: Lint + +on: + push: + branches: + - 'master' + pull_request: + branches: + - '*' + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox + - name: Run lint 💅 + run: tox + env: + TOXENV: flake8 diff --git a/.github/workflows/manage_issues.yml b/.github/workflows/manage_issues.yml new file mode 100644 index 00000000..5876acb5 --- /dev/null +++ b/.github/workflows/manage_issues.yml @@ -0,0 +1,49 @@ +name: Issue Manager + +on: + schedule: + - cron: "0 0 * * *" + issue_comment: + types: + - created + issues: + types: + - labeled + pull_request_target: + types: + - labeled + workflow_dispatch: + +permissions: + issues: write + pull-requests: write + +concurrency: + group: lock + +jobs: + lock-old-closed-issues: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v4 + with: + issue-inactive-days: '180' + process-only: 'issues' + issue-comment: > + This issue has been automatically locked since there + has not been any recent activity after it was closed. + Please open a new issue for related topics referencing + this issue. + close-labelled-issues: + runs-on: ubuntu-latest + steps: + - uses: tiangolo/issue-manager@0.4.0 + with: + token: ${{ secrets.GITHUB_TOKEN }} + config: > + { + "needs-reply": { + "delay": 2200000, + "message": "This issue was closed due to inactivity. If your request is still relevant, please open a new issue referencing this one and provide all of the requested information." + } + } diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..66fe306b --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,44 @@ +name: Tests + +on: + push: + branches: + - 'master' + pull_request: + branches: + - '*' + +jobs: + test: + runs-on: ubuntu-latest + strategy: + max-parallel: 10 + matrix: + sql-alchemy: [ "1.2", "1.3", "1.4","2.0" ] + python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox tox-gh-actions + - name: Test with tox + run: tox + env: + SQLALCHEMY: ${{ matrix.sql-alchemy }} + TOXENV: ${{ matrix.toxenv }} + - name: Upload coverage.xml + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.10' }} + uses: actions/upload-artifact@v4 + with: + name: graphene-sqlalchemy-coverage + path: coverage.xml + if-no-files-found: error + - name: Upload coverage.xml to codecov + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.10' }} + uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore index d4f71e35..1c86b9be 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ __pycache__/ .Python env/ .venv/ +venv/ build/ develop-eggs/ dist/ @@ -26,6 +27,7 @@ var/ *.egg-info/ .installed.cfg *.egg +.python-version # PyInstaller # Usually these files are written by a python script from a template @@ -47,6 +49,7 @@ nosetests.xml coverage.xml *,cover .pytest_cache/ +.benchmarks/ # Translations *.mo @@ -67,3 +70,9 @@ target/ # Databases *.sqlite3 .vscode + +# Schema +*.gql + +# mypy cache +.mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 136f8e7a..262e7608 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,25 +1,30 @@ default_language_version: - python: python3.7 + python: python3.8 repos: -- repo: git://github.com/pre-commit/pre-commit-hooks - rev: c8bad492e1b1d65d9126dba3fe3bd49a5a52b9d6 # v2.1.0 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 hooks: - - id: check-merge-conflict - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer + - id: check-merge-conflict + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer exclude: ^docs/.*$ - - id: trailing-whitespace + - id: trailing-whitespace exclude: README.md -- repo: git://github.com/PyCQA/flake8 - rev: 88caf5ac484f5c09aedc02167c59c66ff0af0068 # 3.7.7 + - repo: https://github.com/pycqa/isort + rev: 5.12.0 hooks: - - id: flake8 -- repo: git://github.com/asottile/seed-isort-config - rev: v1.7.0 + - id: isort + name: isort (python) + - repo: https://github.com/asottile/pyupgrade + rev: v2.37.3 hooks: - - id: seed-isort-config -- repo: git://github.com/pre-commit/mirrors-isort - rev: v4.3.4 + - id: pyupgrade + - repo: https://github.com/psf/black + rev: 22.6.0 hooks: - - id: isort + - id: black + - repo: https://github.com/PyCQA/flake8 + rev: 4.0.0 + hooks: + - id: flake8 diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 39151a5d..00000000 --- a/.travis.yml +++ /dev/null @@ -1,50 +0,0 @@ -language: python -matrix: - include: - # Python 2.7 - - env: TOXENV=py27 - python: 2.7 - # Python 3.5 - - env: TOXENV=py34 - python: 3.4 - # Python 3.5 - - env: TOXENV=py35 - python: 3.5 - # Python 3.6 - - env: TOXENV=py36 - python: 3.6 - # Python 3.7 - - env: TOXENV=py37 - python: 3.7 - dist: xenial - # SQLAlchemy 1.1 - - env: TOXENV=py37-sql11 - python: 3.7 - dist: xenial - # SQLAlchemy 1.2 - - env: TOXENV=py37-sql12 - python: 3.7 - dist: xenial - # SQLAlchemy 1.3 - - env: TOXENV=py37-sql13 - python: 3.7 - dist: xenial - # Pre-commit - - env: TOXENV=pre-commit - python: 3.7 - dist: xenial -install: pip install .[dev] -script: tox -after_success: coveralls -cache: - directories: - - $HOME/.cache/pip - - $HOME/.cache/pre-commit -deploy: - provider: pypi - user: syrusakbary - on: - tags: true - password: - secure: q0ey31cWljGB30l43aEd1KIPuAHRutzmsd2lBb/2zvD79ReBrzvCdFAkH2xcyo4Volk3aazQQTNUIurnTuvBxmtqja0e+gUaO5LdOcokVdOGyLABXh7qhd2kdvbTDWgSwA4EWneLGXn/SjXSe0f3pCcrwc6WDcLAHxtffMvO9gulpYQtUoOqXfMipMOkRD9iDWTJBsSo3trL70X1FHOVr6Yqi0mfkX2Y/imxn6wlTWRz28Ru94xrj27OmUnCv7qcG0taO8LNlUCquNFAr2sZ+l+U/GkQrrM1y+ehPz3pmI0cCCd7SX/7+EG9ViZ07BZ31nk4pgnqjmj3nFwqnCE/4IApGnduqtrMDF63C9TnB1TU8oJmbbUCu4ODwRpBPZMnwzaHsLnrpdrB89/98NtTfujdrh3U5bVB+t33yxrXVh+FjgLYj9PVeDixpFDn6V/Xcnv4BbRMNOhXIQT7a7/5b99RiXBjCk6KRu+Jdu5DZ+3G4Nbr4oim3kZFPUHa555qbzTlwAfkrQxKv3C3OdVJR7eGc9ADsbHyEJbdPNAh/T+xblXTXLS3hPYDvgM+WEGy3CytBDG3JVcXm25ZP96EDWjweJ7MyfylubhuKj/iR1Y1wiHeIsYq9CqRrFQUWL8gFJBfmgjs96xRXXXnvyLtKUKpKw3wFg5cR/6FnLeYZ8k= - distributions: "sdist bdist_wheel" diff --git a/README.md b/README.md index 2ba0d1cb..4e61f96c 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,24 @@ -Please read [UPGRADE-v2.0.md](https://github.com/graphql-python/graphene/blob/master/UPGRADE-v2.0.md) -to learn how to upgrade to Graphene `2.0`. +Version 3.0 is in beta stage. Please read https://github.com/graphql-python/graphene-sqlalchemy/issues/348 to learn about progress and changes in upcoming +beta releases. --- -# ![Graphene Logo](http://graphene-python.org/favicon.png) Graphene-SQLAlchemy [![Build Status](https://travis-ci.org/graphql-python/graphene-sqlalchemy.svg?branch=master)](https://travis-ci.org/graphql-python/graphene-sqlalchemy) [![PyPI version](https://badge.fury.io/py/graphene-sqlalchemy.svg)](https://badge.fury.io/py/graphene-sqlalchemy) [![Coverage Status](https://coveralls.io/repos/graphql-python/graphene-sqlalchemy/badge.svg?branch=master&service=github)](https://coveralls.io/github/graphql-python/graphene-sqlalchemy?branch=master) +# ![Graphene Logo](http://graphene-python.org/favicon.png) Graphene-SQLAlchemy +[![Build Status](https://github.com/graphql-python/graphene-sqlalchemy/workflows/Tests/badge.svg)](https://github.com/graphql-python/graphene-sqlalchemy/actions) +[![PyPI version](https://badge.fury.io/py/graphene-sqlalchemy.svg)](https://badge.fury.io/py/graphene-sqlalchemy) +![GitHub release (latest by date including pre-releases)](https://img.shields.io/github/v/release/graphql-python/graphene-sqlalchemy?color=green&include_prereleases&label=latest) +[![codecov](https://codecov.io/gh/graphql-python/graphene-sqlalchemy/branch/master/graph/badge.svg?token=Zi5S1TikeN)](https://codecov.io/gh/graphql-python/graphene-sqlalchemy) + A [SQLAlchemy](http://www.sqlalchemy.org/) integration for [Graphene](http://graphene-python.org/). ## Installation -For instaling graphene, just run this command in your shell +For installing Graphene, just run this command in your shell. ```bash -pip install "graphene-sqlalchemy>=2.0" +pip install --pre "graphene-sqlalchemy" ``` ## Examples @@ -34,7 +39,7 @@ class UserModel(Base): last_name = Column(String) ``` -To create a GraphQL schema for it you simply have to write the following: +To create a GraphQL schema for it, you simply have to write the following: ```python import graphene @@ -43,10 +48,10 @@ from graphene_sqlalchemy import SQLAlchemyObjectType class User(SQLAlchemyObjectType): class Meta: model = UserModel - # only return specified fields - only_fields = ("name",) - # exclude specified fields - exclude_fields = ("last_name",) + # use `only_fields` to only expose specific fields ie "name" + # only_fields = ("name",) + # use `exclude_fields` to exclude specific fields ie "last_name" + # exclude_fields = ("last_name",) class Query(graphene.ObjectType): users = graphene.List(User) @@ -58,6 +63,21 @@ class Query(graphene.ObjectType): schema = graphene.Schema(query=Query) ``` +We need a database session first: + +```python +from sqlalchemy import (create_engine) +from sqlalchemy.orm import (scoped_session, sessionmaker) + +engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) +db_session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) +# We will need this for querying, Graphene extracts the session from the base. +# Alternatively it can be provided in the GraphQLResolveInfo.context dictionary under context["session"] +Base.query = db_session.query_property() +``` + Then you can simply query the schema: ```python @@ -104,11 +124,11 @@ schema = graphene.Schema(query=Query) ### Full Examples -To learn more check out the following [examples](examples/): +To learn more check out the following [examples](https://github.com/graphql-python/graphene-sqlalchemy/tree/master/examples/): -- [Flask SQLAlchemy example](examples/flask_sqlalchemy) -- [Nameko SQLAlchemy example](examples/nameko_sqlalchemy) +- [Flask SQLAlchemy example](https://github.com/graphql-python/graphene-sqlalchemy/tree/master/examples/flask_sqlalchemy) +- [Nameko SQLAlchemy example](https://github.com/graphql-python/graphene-sqlalchemy/tree/master/examples/nameko_sqlalchemy) ## Contributing -See [CONTRIBUTING.md](/CONTRIBUTING.md) +See [CONTRIBUTING.md](https://github.com/graphql-python/graphene-sqlalchemy/blob/master/CONTRIBUTING.md) diff --git a/README.rst b/README.rst deleted file mode 100644 index d82b8071..00000000 --- a/README.rst +++ /dev/null @@ -1,102 +0,0 @@ -Please read -`UPGRADE-v2.0.md `__ -to learn how to upgrade to Graphene ``2.0``. - --------------- - -|Graphene Logo| Graphene-SQLAlchemy |Build Status| |PyPI version| |Coverage Status| -=================================================================================== - -A `SQLAlchemy `__ integration for -`Graphene `__. - -Installation ------------- - -For instaling graphene, just run this command in your shell - -.. code:: bash - - pip install "graphene-sqlalchemy>=2.0" - -Examples --------- - -Here is a simple SQLAlchemy model: - -.. code:: python - - from sqlalchemy import Column, Integer, String - from sqlalchemy.orm import backref, relationship - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base() - - class UserModel(Base): - __tablename__ = 'department' - id = Column(Integer, primary_key=True) - name = Column(String) - last_name = Column(String) - -To create a GraphQL schema for it you simply have to write the -following: - -.. code:: python - - from graphene_sqlalchemy import SQLAlchemyObjectType - - class User(SQLAlchemyObjectType): - class Meta: - model = UserModel - - class Query(graphene.ObjectType): - users = graphene.List(User) - - def resolve_users(self, info): - query = User.get_query(info) # SQLAlchemy query - return query.all() - - schema = graphene.Schema(query=Query) - -Then you can simply query the schema: - -.. code:: python - - query = ''' - query { - users { - name, - lastName - } - } - ''' - result = schema.execute(query, context_value={'session': db_session}) - -To learn more check out the following `examples `__: - -- **Full example**: `Flask SQLAlchemy - example `__ - -Contributing ------------- - -After cloning this repo, ensure dependencies are installed by running: - -.. code:: sh - - python setup.py install - -After developing, the full test suite can be evaluated by running: - -.. code:: sh - - python setup.py test # Use --pytest-args="-v -s" for verbose mode - -.. |Graphene Logo| image:: http://graphene-python.org/favicon.png -.. |Build Status| image:: https://travis-ci.org/graphql-python/graphene-sqlalchemy.svg?branch=master - :target: https://travis-ci.org/graphql-python/graphene-sqlalchemy -.. |PyPI version| image:: https://badge.fury.io/py/graphene-sqlalchemy.svg - :target: https://badge.fury.io/py/graphene-sqlalchemy -.. |Coverage Status| image:: https://coveralls.io/repos/graphql-python/graphene-sqlalchemy/badge.svg?branch=master&service=github - :target: https://coveralls.io/github/graphql-python/graphene-sqlalchemy?branch=master diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 00000000..237cf1b0 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,18 @@ +API Reference +============== + +SQLAlchemyObjectType +-------------------- +.. autoclass:: graphene_sqlalchemy.SQLAlchemyObjectType + +SQLAlchemyInterface +------------------- +.. autoclass:: graphene_sqlalchemy.SQLAlchemyInterface + +ORMField +-------------------- +.. autoclass:: graphene_sqlalchemy.types.ORMField + +SQLAlchemyConnectionField +------------------------- +.. autoclass:: graphene_sqlalchemy.SQLAlchemyConnectionField diff --git a/docs/conf.py b/docs/conf.py index 3fa6391d..1d8830b6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,6 +1,6 @@ import os -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" # -*- coding: utf-8 -*- # @@ -23,7 +23,10 @@ # import os # import sys # sys.path.insert(0, os.path.abspath('.')) +import os +import sys +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. @@ -34,53 +37,53 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.viewcode", ] if not on_rtd: extensions += [ - 'sphinx.ext.githubpages', + "sphinx.ext.githubpages", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Graphene Django' -copyright = u'Graphene 2016' -author = u'Syrus Akbary' +project = "Graphene Django" +copyright = "Graphene 2016" +author = "Syrus Akbary" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = u'1.0' +version = "1.0" # The full version, including alpha/beta/rc tags. -release = u'1.0.dev' +release = "1.0.dev" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: @@ -94,7 +97,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -116,7 +119,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -175,7 +178,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +# html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -255,34 +258,30 @@ # html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'Graphenedoc' +htmlhelp_basename = "Graphenedoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Graphene.tex', u'Graphene Documentation', - u'Syrus Akbary', 'manual'), + (master_doc, "Graphene.tex", "Graphene Documentation", "Syrus Akbary", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -323,8 +322,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'graphene_django', u'Graphene Django Documentation', - [author], 1) + (master_doc, "graphene_django", "Graphene Django Documentation", [author], 1) ] # If true, show URL addresses after external links. @@ -338,9 +336,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Graphene-Django', u'Graphene Django Documentation', - author, 'Graphene Django', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "Graphene-Django", + "Graphene Django Documentation", + author, + "Graphene Django", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. @@ -414,7 +418,7 @@ # epub_post_files = [] # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # The depth of the table of contents in toc.ncx. # @@ -446,4 +450,4 @@ # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} diff --git a/docs/examples.rst b/docs/examples.rst index 283a0f5e..2013cfbb 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -13,22 +13,12 @@ Search all Models with Union interfaces = (relay.Node,) - class BookConnection(relay.Connection): - class Meta: - node = Book - - class Author(SQLAlchemyObjectType): class Meta: model = AuthorModel interfaces = (relay.Node,) - class AuthorConnection(relay.Connection): - class Meta: - node = Author - - class SearchResult(graphene.Union): class Meta: types = (Book, Author) @@ -39,8 +29,8 @@ Search all Models with Union search = graphene.List(SearchResult, q=graphene.String()) # List field for search results # Normal Fields - all_books = SQLAlchemyConnectionField(BookConnection) - all_authors = SQLAlchemyConnectionField(AuthorConnection) + all_books = SQLAlchemyConnectionField(Book.connection) + all_authors = SQLAlchemyConnectionField(Author.connection) def resolve_search(self, info, **args): q = args.get("q") # Search query diff --git a/docs/filters.rst b/docs/filters.rst new file mode 100644 index 00000000..ac36803d --- /dev/null +++ b/docs/filters.rst @@ -0,0 +1,213 @@ +======= +Filters +======= + +Starting in graphene-sqlalchemy version 3, the SQLAlchemyConnectionField class implements filtering by default. The query utilizes a ``filter`` keyword to specify a filter class that inherits from ``graphene.InputObjectType``. + +Migrating from graphene-sqlalchemy-filter +--------------------------------------------- + +If like many of us, you have been using |graphene-sqlalchemy-filter|_ to implement filters and would like to use the in-built mechanism here, there are a couple key differences to note. Mainly, in an effort to simplify the generated schema, filter keywords are nested under their respective fields instead of concatenated. For example, the filter partial ``{usernameIn: ["moderator", "cool guy"]}`` would be represented as ``{username: {in: ["moderator", "cool guy"]}}``. + +.. |graphene-sqlalchemy-filter| replace:: ``graphene-sqlalchemy-filter`` +.. _graphene-sqlalchemy-filter: https://github.com/art1415926535/graphene-sqlalchemy-filter + +Further, some of the constructs found in libraries like `DGraph's DQL `_ have been implemented, so if you have created custom implementations for these features, you may want to take a look at the examples below. + + +Example model +------------- + +Take as example a Pet model similar to that in the sorting example. We will use variations on this arrangement for the following examples. + +.. code:: + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + + + class Query(graphene.ObjectType): + allPets = SQLAlchemyConnectionField(PetNode.connection) + + +Simple filter example +--------------------- + +Filters are defined at the object level through the ``BaseTypeFilter`` class. The ``BaseType`` encompasses both Graphene ``ObjectType``\ s and ``Interface``\ s. Each ``BaseTypeFilter`` instance may define fields via ``FieldFilter`` and relationships via ``RelationshipFilter``. Here's a basic example querying a single field on the Pet model: + +.. code:: + + allPets(filter: {name: {eq: "Fido"}}){ + edges { + node { + name + } + } + } + +This will return all pets with the name "Fido". + + +Custom filter types +------------------- + +If you'd like to implement custom behavior for filtering a field, you can do so by extending one of the base filter classes in ``graphene_sqlalchemy.filters``. For example, if you'd like to add a ``divisible_by`` keyword to filter the age attribute on the ``Pet`` model, you can do so as follows: + +.. code:: python + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + ... + + age = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + +Filtering over relationships with RelationshipFilter +---------------------------------------------------- + +When a filter class field refers to another object in a relationship, you may nest filters on relationship object attributes. This happens directly for 1:1 and m:1 relationships and through the ``contains`` and ``containsExactly`` keywords for 1:n and m:n relationships. + + +:1 relationships +^^^^^^^^^^^^^^^^ + +When an object or interface defines a singular relationship, relationship object attributes may be filtered directly like so: + +Take the following SQLAlchemy model definition as an example: + +.. code:: python + + class Pet + ... + person_id = Column(Integer(), ForeignKey("people.id")) + + class Person + ... + pets = relationship("Pet", backref="person") + + +Then, this query will return all pets whose person is named "Ada": + +.. code:: + + allPets(filter: { + person: {name: {eq: "Ada"}} + }) { + ... + } + + +:n relationships +^^^^^^^^^^^^^^^^ + +However, for plural relationships, relationship object attributes must be filtered through either ``contains`` or ``containsExactly``: + +Now, using a many-to-many model definition: + +.. code:: python + + people_pets_table = sqlalchemy.Table( + "people_pets", + Base.metadata, + Column("person_id", ForeignKey("people.id")), + Column("pet_id", ForeignKey("pets.id")), + ) + + class Pet + ... + + class Person + ... + pets = relationship("Pet", backref="people") + + +this query will return all pets which have a person named "Ben" in their ``people`` list. + +.. code:: + + allPets(filter: { + people: { + contains: [{name: {eq: "Ben"}}], + } + }) { + ... + } + + +and this one will return all pets which hvae a person list that contains exactly the people "Ada" and "Ben" and no fewer or people with other names. + +.. code:: + + allPets(filter: { + articles: { + containsExactly: [ + {name: {eq: "Ada"}}, + {name: {eq: "Ben"}}, + ], + } + }) { + ... + } + +And/Or Logic +------------ + +Filters can also be chained together logically using `and` and `or` keywords nested under `filter`. Clauses are passed directly to `sqlalchemy.and_` and `slqlalchemy.or_`, respectively. To return all pets named "Fido" or "Spot", use: + + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + {name: {eq: "Spot"}}, + ] + }) { + ... + } + +And to return all pets that are named "Fido" or are 5 years old and named "Spot", use: + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + { and: [ + {name: {eq: "Spot"}}, + {age: {eq: 5}} + } + ] + }) { + ... + } + + +Hybrid Property support +----------------------- + +Filtering over SQLAlchemy `hybrid properties `_ is fully supported. + + +Reporting feedback and bugs +--------------------------- + +Filtering is a new feature to graphene-sqlalchemy, so please `post an issue on Github `_ if you run into any problems or have ideas on how to improve the implementation. diff --git a/docs/index.rst b/docs/index.rst index 81b2f316..4245eba8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,6 +6,11 @@ Contents: .. toctree:: :maxdepth: 0 - tutorial + starter + inheritance + relay tips + filters examples + tutorial + api diff --git a/docs/inheritance.rst b/docs/inheritance.rst new file mode 100644 index 00000000..d7fcca9d --- /dev/null +++ b/docs/inheritance.rst @@ -0,0 +1,152 @@ +Inheritance Examples +==================== + + +Create interfaces from inheritance relationships +------------------------------------------------ + +.. note:: + If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_. + +SQLAlchemy has excellent support for class inheritance hierarchies. +These hierarchies can be represented in your GraphQL schema by means +of interfaces_. Much like ObjectTypes, Interfaces in +Graphene-SQLAlchemy are able to infer their fields and relationships +from the attributes of their underlying SQLAlchemy model: + +.. _interfaces: https://docs.graphene-python.org/en/latest/types/interfaces/ + +.. code:: python + + from sqlalchemy import Column, Date, Integer, String + from sqlalchemy.ext.declarative import declarative_base + + import graphene + from graphene import relay + from graphene_sqlalchemy import SQLAlchemyInterface, SQLAlchemyObjectType + + Base = declarative_base() + + class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + } + + class Employee(Person): + hire_date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "employee", + } + + class Customer(Person): + first_purchase_date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "customer", + } + + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (relay.Node, PersonType) + + class CustomerType(SQLAlchemyObjectType): + class Meta: + model = Customer + interfaces = (relay.Node, PersonType) + +Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must +be linked to an abstract Model that does not specify a `polymorphic_identity`, +because we cannot return instances of interfaces from a GraphQL query. +If Person specified a `polymorphic_identity`, instances of Person could +be inserted into and returned by the database, potentially causing +Persons to be returned to the resolvers. + +When querying on the base type, you can refer directly to common fields, +and fields on concrete implementations using the `... on` syntax: + + +.. code:: + + people { + name + birthDate + ... on EmployeeType { + hireDate + } + ... on CustomerType { + firstPurchaseDate + } + } + + +.. danger:: + When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications. + See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`. + +Please note that by default, the "polymorphic_on" column is *not* +generated as a field on types that use polymorphic inheritance, as +this is considered an implementation detail. The idiomatic way to +retrieve the concrete GraphQL type of an object is to query for the +`__typename` field. +To override this behavior, an `ORMField` needs to be created +for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended* +as it promotes abiguous schema design + +If your SQLAlchemy model only specifies a relationship to the +base type, you will need to explicitly pass your concrete implementation +class to the Schema constructor via the `types=` argument: + +.. code:: python + + schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType]) + + +See also: `Graphene Interfaces `_ + + +Eager Loading & Using with AsyncSession +---------------------------------------- + +When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly. +This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables. +To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model: + +.. code:: python + + class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session + } + +Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers: + +.. code:: python + + class Query(graphene.ObjectType): + people = graphene.Field(graphene.List(PersonType)) + + async def resolve_people(self, _info): + return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all() + +Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR. + +For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs `_. diff --git a/docs/relay.rst b/docs/relay.rst new file mode 100644 index 00000000..7b733c76 --- /dev/null +++ b/docs/relay.rst @@ -0,0 +1,43 @@ +Relay +========== + +:code:`graphene-sqlalchemy` comes with pre-defined +connection fields to quickly create a functioning relay API. +Using the :code:`SQLAlchemyConnectionField`, you have access to relay pagination, +sorting and filtering (filtering is coming soon!). + +To be used in a relay connection, your :code:`SQLAlchemyObjectType` must implement +the :code:`Node` interface from :code:`graphene.relay`. This handles the creation of +the :code:`Connection` and :code:`Edge` types automatically. + +The following example creates a relay-paginated connection: + + + +.. code:: python + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(Enum('cat', 'dog', name='pet_kind'), nullable=False) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces=(Node,) + + + class Query(ObjectType): + all_pets = SQLAlchemyConnectionField(PetNode.connection) + +To disable sorting on the connection, you can set :code:`sort` to :code:`None` the +:code:`SQLAlchemyConnectionField`: + + +.. code:: python + + class Query(ObjectType): + all_pets = SQLAlchemyConnectionField(PetNode.connection, sort=None) + diff --git a/docs/requirements.txt b/docs/requirements.txt index 666a8c9d..220b7cfb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,3 @@ +sphinx # Docs template http://graphene-python.org/sphinx_graphene_theme.zip diff --git a/docs/starter.rst b/docs/starter.rst new file mode 100644 index 00000000..6e09ab00 --- /dev/null +++ b/docs/starter.rst @@ -0,0 +1,118 @@ +Getting Started +================= + +Welcome to the graphene-sqlalchemy documentation! +Graphene is a powerful Python library for building GraphQL APIs, +and SQLAlchemy is a popular ORM (Object-Relational Mapping) +tool for working with databases. When combined, graphene-sqlalchemy +allows developers to quickly and easily create a GraphQL API that +seamlessly interacts with a SQLAlchemy-managed database. +It is fully compatible with SQLAlchemy 1.4 and 2.0. +This documentation provides detailed instructions on how to get +started with graphene-sqlalchemy, including installation, setup, +and usage examples. + +Installation +------------ + +To install :code:`graphene-sqlalchemy`, just run this command in your shell: + +.. code:: bash + + pip install --pre "graphene-sqlalchemy" + +Examples +-------- + +Here is a simple SQLAlchemy model: + +.. code:: python + + from sqlalchemy import Column, Integer, String + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class UserModel(Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + name = Column(String) + last_name = Column(String) + +To create a GraphQL schema for it, you simply have to write the +following: + +.. code:: python + + import graphene + from graphene_sqlalchemy import SQLAlchemyObjectType + + class User(SQLAlchemyObjectType): + class Meta: + model = UserModel + # use `only_fields` to only expose specific fields ie "name" + # only_fields = ("name",) + # use `exclude_fields` to exclude specific fields ie "last_name" + # exclude_fields = ("last_name",) + + class Query(graphene.ObjectType): + users = graphene.List(User) + + def resolve_users(self, info): + query = User.get_query(info) # SQLAlchemy query + return query.all() + + schema = graphene.Schema(query=Query) + +Then you can simply query the schema: + +.. code:: python + + query = ''' + query { + users { + name, + lastName + } + } + ''' + result = schema.execute(query, context_value={'session': db_session}) + + +It is important to provide a session for graphene-sqlalchemy to resolve the models. +In this example, it is provided using the GraphQL context. See :doc:`tips` for +other ways to implement this. + +You may also subclass SQLAlchemyObjectType by providing +``abstract = True`` in your subclasses Meta: + +.. code:: python + + from graphene_sqlalchemy import SQLAlchemyObjectType + + class ActiveSQLAlchemyObjectType(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def get_node(cls, info, id): + return cls.get_query(info).filter( + and_(cls._meta.model.deleted_at==None, + cls._meta.model.id==id) + ).first() + + class User(ActiveSQLAlchemyObjectType): + class Meta: + model = UserModel + + class Query(graphene.ObjectType): + users = graphene.List(User) + + def resolve_users(self, info): + query = User.get_query(info) # SQLAlchemy query + return query.all() + + schema = graphene.Schema(query=Query) + +More complex inhertiance using SQLAlchemy's polymorphic models is also supported. +You can check out :doc:`inheritance` for a guide. diff --git a/docs/tips.rst b/docs/tips.rst index 1fd39107..a3ed69ed 100644 --- a/docs/tips.rst +++ b/docs/tips.rst @@ -4,6 +4,7 @@ Tips Querying -------- +.. _querying: In order to make querying against the database work, there are two alternatives: @@ -50,13 +51,8 @@ Given the model model = Pet - class PetConnection(Connection): - class Meta: - node = PetNode - - class Query(ObjectType): - allPets = SQLAlchemyConnectionField(PetConnection) + allPets = SQLAlchemyConnectionField(PetNode.connection) some of the allowed queries are diff --git a/docs/tutorial.rst b/docs/tutorial.rst index bc5ee62d..3c4c135e 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -102,28 +102,18 @@ Create ``flask_sqlalchemy/schema.py`` and type the following: interfaces = (relay.Node, ) - class DepartmentConnection(relay.Connection): - class Meta: - node = Department - - class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel interfaces = (relay.Node, ) - class EmployeeConnection(relay.Connection): - class Meta: - node = Employee - - class Query(graphene.ObjectType): node = relay.Node.Field() # Allows sorting over multiple columns, by default over the primary key - all_employees = SQLAlchemyConnectionField(EmployeeConnection) + all_employees = SQLAlchemyConnectionField(Employee.connection) # Disable sorting over this field - all_departments = SQLAlchemyConnectionField(DepartmentConnection, sort=None) + all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) schema = graphene.Schema(query=Query) diff --git a/examples/filters/README.md b/examples/filters/README.md new file mode 100644 index 00000000..a72e75de --- /dev/null +++ b/examples/filters/README.md @@ -0,0 +1,47 @@ +Example Filters Project +================================ + +This example highlights the ability to filter queries in graphene-sqlalchemy. + +The project contains two models, one named `Department` and another +named `Employee`. + +Getting started +--------------- + +First you'll need to get the source of the project. Do this by cloning the +whole Graphene-SQLAlchemy repository: + +```bash +# Get the example project code +git clone https://github.com/graphql-python/graphene-sqlalchemy.git +cd graphene-sqlalchemy/examples/filters +``` + +It is recommended to create a virtual environment +for this project. We'll do this using +[virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) +to keep things simple, +but you may also find something like +[virtualenvwrapper](https://virtualenvwrapper.readthedocs.org/en/latest/) +to be useful: + +```bash +# Create a virtualenv in which we can install the dependencies +virtualenv env +source env/bin/activate +``` + +Install our dependencies: + +```bash +pip install -r requirements.txt +``` + +The following command will setup the database, and start the server: + +```bash +python app.py +``` + +Now head over to your favorite GraphQL client, POST to [http://127.0.0.1:5000/graphql](http://127.0.0.1:5000/graphql) and run some queries! diff --git a/examples/filters/__init__.py b/examples/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/filters/app.py b/examples/filters/app.py new file mode 100644 index 00000000..ab918da7 --- /dev/null +++ b/examples/filters/app.py @@ -0,0 +1,16 @@ +from database import init_db +from fastapi import FastAPI +from schema import schema +from starlette_graphene3 import GraphQLApp, make_playground_handler + + +def create_app() -> FastAPI: + init_db() + app = FastAPI() + + app.mount("/graphql", GraphQLApp(schema, on_get=make_playground_handler())) + + return app + + +app = create_app() diff --git a/examples/filters/database.py b/examples/filters/database.py new file mode 100644 index 00000000..8f6522f7 --- /dev/null +++ b/examples/filters/database.py @@ -0,0 +1,49 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, echo=True +) +session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +from sqlalchemy.orm import scoped_session as scoped_session_factory + +scoped_session = scoped_session_factory(session_factory) + +Base.query = scoped_session.query_property() +Base.metadata.bind = engine + + +def init_db(): + from models import Person, Pet, Toy + + Base.metadata.create_all() + scoped_session.execute("PRAGMA foreign_keys=on") + db = scoped_session() + + person1 = Person(name="A") + person2 = Person(name="B") + + pet1 = Pet(name="Spot") + pet2 = Pet(name="Milo") + + toy1 = Toy(name="disc") + toy2 = Toy(name="ball") + + person1.pet = pet1 + person2.pet = pet2 + + pet1.toys.append(toy1) + pet2.toys.append(toy1) + pet2.toys.append(toy2) + + db.add(person1) + db.add(person2) + db.add(pet1) + db.add(pet2) + db.add(toy1) + db.add(toy2) + + db.commit() diff --git a/examples/filters/models.py b/examples/filters/models.py new file mode 100644 index 00000000..1b22956b --- /dev/null +++ b/examples/filters/models.py @@ -0,0 +1,34 @@ +import sqlalchemy +from database import Base +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + person_id = Column(Integer(), ForeignKey("people.id")) + + +class Person(Base): + __tablename__ = "people" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + pets = relationship("Pet", backref="person") + + +pets_toys_table = sqlalchemy.Table( + "pets_toys", + Base.metadata, + Column("pet_id", ForeignKey("pets.id")), + Column("toy_id", ForeignKey("toys.id")), +) + + +class Toy(Base): + __tablename__ = "toys" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pets = relationship("Pet", secondary=pets_toys_table, backref="toys") diff --git a/examples/filters/requirements.txt b/examples/filters/requirements.txt new file mode 100644 index 00000000..b433ec59 --- /dev/null +++ b/examples/filters/requirements.txt @@ -0,0 +1,3 @@ +-e ../../ +fastapi +uvicorn diff --git a/examples/filters/run.sh b/examples/filters/run.sh new file mode 100755 index 00000000..ec365444 --- /dev/null +++ b/examples/filters/run.sh @@ -0,0 +1 @@ +uvicorn app:app --port 5000 diff --git a/examples/filters/schema.py b/examples/filters/schema.py new file mode 100644 index 00000000..2728cab7 --- /dev/null +++ b/examples/filters/schema.py @@ -0,0 +1,42 @@ +from models import Person as PersonModel +from models import Pet as PetModel +from models import Toy as ToyModel + +import graphene +from graphene import relay +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene_sqlalchemy.fields import SQLAlchemyConnectionField + + +class Pet(SQLAlchemyObjectType): + class Meta: + model = PetModel + name = "Pet" + interfaces = (relay.Node,) + batching = True + + +class Person(SQLAlchemyObjectType): + class Meta: + model = PersonModel + name = "Person" + interfaces = (relay.Node,) + batching = True + + +class Toy(SQLAlchemyObjectType): + class Meta: + model = ToyModel + name = "Toy" + interfaces = (relay.Node,) + batching = True + + +class Query(graphene.ObjectType): + node = relay.Node.Field() + pets = SQLAlchemyConnectionField(Pet.connection) + people = SQLAlchemyConnectionField(Person.connection) + toys = SQLAlchemyConnectionField(Toy.connection) + + +schema = graphene.Schema(query=Query) diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index ca4d4122..74ec7ca9 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -2,10 +2,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker -engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) +engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() @@ -15,24 +15,25 @@ def init_db(): # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() from models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) # Create the fixtures - engineering = Department(name='Engineering') + engineering = Department(name="Engineering") db_session.add(engineering) - hr = Department(name='Human Resources') + hr = Department(name="Human Resources") db_session.add(hr) - manager = Role(name='manager') + manager = Role(name="manager") db_session.add(manager) - engineer = Role(name='engineer') + engineer = Role(name="engineer") db_session.add(engineer) - peter = Employee(name='Peter', department=engineering, role=engineer) + peter = Employee(name="Peter", department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering, role=engineer) + roy = Employee(name="Roy", department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr, role=manager) + tracy = Employee(name="Tracy", department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index efbbe690..38f0fd0a 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -4,35 +4,31 @@ class Department(Base): - __tablename__ = 'department' + __tablename__ = "department" id = Column(Integer, primary_key=True) name = Column(String) class Role(Base): - __tablename__ = 'roles' + __tablename__ = "roles" role_id = Column(Integer, primary_key=True) name = Column(String) class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String) # Use default=func.now() to set the default hiring time # of an Employee to be the current time when an # Employee record was created hired_on = Column(DateTime, default=func.now()) - department_id = Column(Integer, ForeignKey('department.id')) - role_id = Column(Integer, ForeignKey('roles.role_id')) + department_id = Column(Integer, ForeignKey("department.id")) + role_id = Column(Integer, ForeignKey("roles.role_id")) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( - Department, - backref=backref('employees', - uselist=True, - cascade='delete,all')) + Department, backref=backref("employees", uselist=True, cascade="delete,all") + ) role = relationship( - Role, - backref=backref('roles', - uselist=True, - cascade='delete,all')) + Role, backref=backref("roles", uselist=True, cascade="delete,all") + ) diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index 9ed09464..c4a91e63 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -10,30 +10,31 @@ class Department(SQLAlchemyObjectType): class Meta: model = DepartmentModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Role(SQLAlchemyObjectType): class Meta: model = RoleModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Query(graphene.ObjectType): node = relay.Node.Field() # Allow only single column sorting all_employees = SQLAlchemyConnectionField( - Employee, sort=Employee.sort_argument()) + Employee.connection, sort=Employee.sort_argument() + ) # Allows sorting over multiple columns, by default over the primary key - all_roles = SQLAlchemyConnectionField(Role) + all_roles = SQLAlchemyConnectionField(Role.connection) # Disable sorting over this field - all_departments = SQLAlchemyConnectionField(Department, sort=None) + all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) -schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) +schema = graphene.Schema(query=Query) diff --git a/examples/nameko_sqlalchemy/README.md b/examples/nameko_sqlalchemy/README.md index 6302cb33..e0803895 100644 --- a/examples/nameko_sqlalchemy/README.md +++ b/examples/nameko_sqlalchemy/README.md @@ -46,7 +46,6 @@ Now the following command will setup the database, and start the server: ```bash ./run.sh - ``` Now head on over to postman and send POST request to: diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py index 05352529..64d305ea 100755 --- a/examples/nameko_sqlalchemy/app.py +++ b/examples/nameko_sqlalchemy/app.py @@ -1,37 +1,45 @@ from database import db_session, init_db from schema import schema -from graphql_server import (HttpQueryError, default_format_error, - encode_execution_results, json_encode, - load_json_body, run_http_query) - - -class App(): - def __init__(self): - init_db() - - def query(self, request): - data = self.parse_body(request) - execution_results, params = run_http_query( - schema, - 'post', - data) - result, status_code = encode_execution_results( - execution_results, - format_error=default_format_error,is_batch=False, encode=json_encode) - return result - - def parse_body(self,request): - # We use mimetype here since we don't need the other - # information provided by content_type - content_type = request.mimetype - if content_type == 'application/graphql': - return {'query': request.data.decode('utf8')} - - elif content_type == 'application/json': - return load_json_body(request.data.decode('utf8')) - - elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'): - return request.form - - return {} +from graphql_server import ( + HttpQueryError, + default_format_error, + encode_execution_results, + json_encode, + load_json_body, + run_http_query, +) + + +class App: + def __init__(self): + init_db() + + def query(self, request): + data = self.parse_body(request) + execution_results, params = run_http_query(schema, "post", data) + result, status_code = encode_execution_results( + execution_results, + format_error=default_format_error, + is_batch=False, + encode=json_encode, + ) + return result + + def parse_body(self, request): + # We use mimetype here since we don't need the other + # information provided by content_type + content_type = request.mimetype + if content_type == "application/graphql": + return {"query": request.data.decode("utf8")} + + elif content_type == "application/json": + return load_json_body(request.data.decode("utf8")) + + elif content_type in ( + "application/x-www-form-urlencoded", + "multipart/form-data", + ): + return request.form + + return {} diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py index 01e76ca6..74ec7ca9 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -2,10 +2,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker -engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) +engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() @@ -14,25 +14,26 @@ def init_db(): # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - from .models import Department, Employee, Role + from models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) # Create the fixtures - engineering = Department(name='Engineering') + engineering = Department(name="Engineering") db_session.add(engineering) - hr = Department(name='Human Resources') + hr = Department(name="Human Resources") db_session.add(hr) - manager = Role(name='manager') + manager = Role(name="manager") db_session.add(manager) - engineer = Role(name='engineer') + engineer = Role(name="engineer") db_session.add(engineer) - peter = Employee(name='Peter', department=engineering, role=engineer) + peter = Employee(name="Peter", department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering, role=engineer) + roy = Employee(name="Roy", department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr, role=manager) + tracy = Employee(name="Tracy", department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/nameko_sqlalchemy/models.py b/examples/nameko_sqlalchemy/models.py index efbbe690..38f0fd0a 100644 --- a/examples/nameko_sqlalchemy/models.py +++ b/examples/nameko_sqlalchemy/models.py @@ -4,35 +4,31 @@ class Department(Base): - __tablename__ = 'department' + __tablename__ = "department" id = Column(Integer, primary_key=True) name = Column(String) class Role(Base): - __tablename__ = 'roles' + __tablename__ = "roles" role_id = Column(Integer, primary_key=True) name = Column(String) class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String) # Use default=func.now() to set the default hiring time # of an Employee to be the current time when an # Employee record was created hired_on = Column(DateTime, default=func.now()) - department_id = Column(Integer, ForeignKey('department.id')) - role_id = Column(Integer, ForeignKey('roles.role_id')) + department_id = Column(Integer, ForeignKey("department.id")) + role_id = Column(Integer, ForeignKey("roles.role_id")) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( - Department, - backref=backref('employees', - uselist=True, - cascade='delete,all')) + Department, backref=backref("employees", uselist=True, cascade="delete,all") + ) role = relationship( - Role, - backref=backref('roles', - uselist=True, - cascade='delete,all')) + Role, backref=backref("roles", uselist=True, cascade="delete,all") + ) diff --git a/examples/nameko_sqlalchemy/schema.py b/examples/nameko_sqlalchemy/schema.py index a33cab9b..ced300b3 100644 --- a/examples/nameko_sqlalchemy/schema.py +++ b/examples/nameko_sqlalchemy/schema.py @@ -8,31 +8,28 @@ class Department(SQLAlchemyObjectType): - class Meta: model = DepartmentModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Employee(SQLAlchemyObjectType): - class Meta: model = EmployeeModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Role(SQLAlchemyObjectType): - class Meta: model = RoleModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) - all_roles = SQLAlchemyConnectionField(Role) + all_employees = SQLAlchemyConnectionField(Employee.connection) + all_roles = SQLAlchemyConnectionField(Role.connection) role = graphene.Field(Role) -schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) +schema = graphene.Schema(query=Query) diff --git a/examples/nameko_sqlalchemy/service.py b/examples/nameko_sqlalchemy/service.py index d9c519c9..7f4c5078 100644 --- a/examples/nameko_sqlalchemy/service.py +++ b/examples/nameko_sqlalchemy/service.py @@ -4,8 +4,8 @@ class DepartmentService: - name = 'department' + name = "department" - @http('POST', '/graphql') + @http("POST", "/graphql") def query(self, request): return App().query(request) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index d8eb010b..69bb79bb 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,11 +1,12 @@ -from .types import SQLAlchemyObjectType from .fields import SQLAlchemyConnectionField +from .types import SQLAlchemyInterface, SQLAlchemyObjectType from .utils import get_query, get_session -__version__ = "2.2.0" +__version__ = "3.0.0rc2" __all__ = [ "__version__", + "SQLAlchemyInterface", "SQLAlchemyObjectType", "SQLAlchemyConnectionField", "get_query", diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py new file mode 100644 index 00000000..731d7645 --- /dev/null +++ b/graphene_sqlalchemy/batching.py @@ -0,0 +1,143 @@ +"""The dataloader uses "select in loading" strategy to load related entities.""" +from asyncio import get_event_loop +from typing import Any, Dict + +import sqlalchemy +from sqlalchemy.orm import Session, strategies +from sqlalchemy.orm.query import QueryContext +from sqlalchemy.util import immutabledict + +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, + is_graphene_version_less_than, +) + + +def get_data_loader_impl() -> Any: # pragma: no cover + """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility, + aiodataloader is used in conjunction with older versions of graphene""" + if is_graphene_version_less_than("3.1.1"): + from aiodataloader import DataLoader + else: + from graphene.utils.dataloader import DataLoader + + return DataLoader + + +DataLoader = get_data_loader_impl() + + +class RelationshipLoader(DataLoader): + cache = False + + def __init__(self, relationship_prop, selectin_loader): + super().__init__() + self.relationship_prop = relationship_prop + self.selectin_loader = selectin_loader + + async def batch_load_fn(self, parents): + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + than re-implementing and maintainnig a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emits 2 `SELECT` statements + when callling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maitainer suggestion. + See https://git.io/JewQ7 + """ + child_mapper = self.relationship_prop.mapper + parent_mapper = self.relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + # Should the boolean be set to False? Does it matter for our purposes? + states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only used to get the session + query_context = None + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + parent_mapper_query = session.query(parent_mapper.entity) + query_context = parent_mapper_query._compile_context() + else: + query_context = QueryContext(session.query(parent_mapper.entity)) + if SQL_VERSION_HIGHER_EQUAL_THAN_2: # pragma: no cover + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + None, # recursion depth can be none + immutabledict(), # default value for selectinload->lazyload + ) + elif SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + ) + else: + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + return [getattr(parent, self.relationship_prop.key) for parent in parents] + + +# Cache this across `batch_load_fn` calls +# This is so SQL string generation is cached under-the-hood via `bakery` +# Caching the relationship loader for each relationship prop. +RELATIONSHIP_LOADERS_CACHE: Dict[ + sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader +] = {} + + +def get_batch_resolver(relationship_prop): + """Get the resolve function for the given relationship.""" + + def _get_loader(relationship_prop): + """Retrieve the cached loader of the given relationship.""" + loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) + if loader is None or loader.loop != get_event_loop(): + selectin_loader = strategies.SelectInLoader( + relationship_prop, (("lazy", "selectin"),) + ) + loader = RelationshipLoader( + relationship_prop=relationship_prop, + selectin_loader=selectin_loader, + ) + RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader + return loader + + async def resolve(root, info, **args): + return await _get_loader(relationship_prop).load(root) + + return resolve diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 9466cbaf..6502412f 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,19 +1,101 @@ -from singledispatch import singledispatch -from sqlalchemy import types +import datetime +import sys +import typing +import uuid +from decimal import Decimal +from typing import Any, Dict, Optional, TypeVar, Union, cast + +from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import interfaces - -from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - String) +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ( + ColumnProperty, + RelationshipProperty, + class_mapper, + interfaces, + strategies, +) + +import graphene from graphene.types.json import JSONString +from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .registry import get_global_registry +from .registry import Registry, get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + DummyImport, + column_type_eq, + registry_sqlalchemy_model_from_str, + safe_isinstance, + safe_issubclass, + singledispatchbymatchfunction, +) + +# Import path changed in 1.4 +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.orm import DeclarativeMeta +else: + from sqlalchemy.ext.declarative import DeclarativeMeta + +# We just use MapperProperties for type hints, they don't exist in sqlalchemy < 1.4 +try: + from sqlalchemy import MapperProperty +except ImportError: + # sqlalchemy < 1.4 + MapperProperty = Any + +try: + from typing import ForwardRef +except ImportError: + # python 3.6 + from typing import _ForwardRef as ForwardRef try: - from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType + from sqlalchemy_utils.types.choice import EnumTypeImpl except ImportError: - ChoiceType = JSONType = ScalarListType = TSVectorType = object + EnumTypeImpl = object + +try: + import sqlalchemy_utils as sqa_utils +except ImportError: + sqa_utils = DummyImport() + +is_selectin_available = getattr(strategies, "SelectInLoader", None) + +""" +Flag for whether to generate stricter non-null fields for many-relationships. + +For many-relationships, both the list element and the list field itself will be +non-null by default. This better matches ORM semantics, where there is always a +list for a many relationship (even if it is empty), and it never contains None. + +This option can be set to False to revert to pre-3.0 behavior. + +For example, given a User model with many Comments: + + class User(Base): + comments = relationship("Comment") + +The Schema will be: + + type User { + comments: [Comment!]! + } + +When set to False, the pre-3.0 behavior gives: + + type User { + comments: [Comment] + } +""" +use_non_null_many_relationships = True + + +def set_non_null_many_relationships(non_null_flag): + global use_non_null_many_relationships + use_non_null_many_relationships = non_null_flag def get_column_doc(column): @@ -24,43 +106,192 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): - direction = relationship.direction - model = relationship.mapper.entity +def convert_sqlalchemy_association_proxy( + parent, + assoc_prop, + obj_type, + registry, + connection_field_factory, + batching, + resolver, + **field_kwargs, +): + def dynamic_type(): + prop = class_mapper(parent).attrs[assoc_prop.target_collection] + scalar = not prop.uselist + model = prop.mapper.class_ + attr = class_mapper(model).attrs[assoc_prop.value_attr] + + if isinstance(attr, ColumnProperty): + field = convert_sqlalchemy_column(attr, registry, resolver, **field_kwargs) + if not scalar: + # repackage as List + field.__dict__["_type"] = graphene.List(field.type) + return field + elif isinstance(attr, RelationshipProperty): + return convert_sqlalchemy_relationship( + attr, + obj_type, + connection_field_factory, + field_kwargs.pop("batching", batching), + assoc_prop.value_attr, + **field_kwargs, + ).get_type() + else: + raise TypeError( + "Unsupported association proxy target type: {} for prop {} on type {}. " + "Please disable the conversion of this field using an ORMField.".format( + type(attr), assoc_prop, obj_type + ) + ) + # else, not supported + + return graphene.Dynamic(dynamic_type) + + +def convert_sqlalchemy_relationship( + relationship_prop, + obj_type, + connection_field_factory, + batching, + orm_field_name, + **field_kwargs, +): + """ + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param function|None connection_field_factory: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Dynamic + """ def dynamic_type(): - _type = registry.get_type_for_model(model) - if not _type: + """:rtype: Field|None""" + direction = relationship_prop.direction + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) + batching_ = batching if is_selectin_available else False + + if not child_type: return None - if direction == interfaces.MANYTOONE or not relationship.uselist: - return Field(_type) - elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - if _type._meta.connection: - return connection_field_factory(relationship, registry) - return Field(List(_type)) - return Dynamic(dynamic_type) + if direction == interfaces.MANYTOONE or not relationship_prop.uselist: + return _convert_o2o_or_m2o_relationship( + relationship_prop, obj_type, batching_, orm_field_name, **field_kwargs + ) + + if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): + return _convert_o2m_or_m2m_relationship( + relationship_prop, + obj_type, + batching_, + connection_field_factory, + **field_kwargs, + ) + + return graphene.Dynamic(dynamic_type) + + +def _convert_o2o_or_m2o_relationship( + relationship_prop, obj_type, batching, orm_field_name, **field_kwargs +): + """ + Convert one-to-one or many-to-one relationshsip. Return an object field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) + + resolver = get_custom_resolver(obj_type, orm_field_name) + if resolver is None: + resolver = ( + get_batch_resolver(relationship_prop) + if batching + else get_attr_resolver(obj_type, relationship_prop.key) + ) + + return graphene.Field(child_type, resolver=resolver, **field_kwargs) + + +def _convert_o2m_or_m2m_relationship( + relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs +): + """ + Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param function|None connection_field_factory: + :param dict field_kwargs: + :rtype: Field + """ + from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory + + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) + + if not child_type._meta.connection: + # check if we need to use non-null fields + list_type = ( + graphene.NonNull(graphene.List(graphene.NonNull(child_type))) + if use_non_null_many_relationships + else graphene.List(child_type) + ) + + return graphene.Field(list_type, **field_kwargs) + # TODO Allow override of connection_field_factory and resolver via ORMField + if connection_field_factory is None: + connection_field_factory = ( + BatchSQLAlchemyConnectionField.from_relationship + if batching + else default_connection_field_factory + ) + + return connection_field_factory( + relationship_prop, obj_type._meta.registry, **field_kwargs + ) + + +def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): + if "type_" not in field_kwargs: + field_kwargs["type_"] = convert_hybrid_property_return_type(hybrid_prop) + + if "description" not in field_kwargs: + field_kwargs["description"] = getattr(hybrid_prop, "__doc__", None) -def convert_sqlalchemy_hybrid_method(hybrid_item): - return String(description=getattr(hybrid_item, "__doc__", None), required=False) + return graphene.Field(resolver=resolver, **field_kwargs) -def convert_sqlalchemy_composite(composite, registry): - converter = registry.get_converter_for_composite(composite.composite_class) +def convert_sqlalchemy_composite(composite_prop, registry, resolver): + converter = registry.get_converter_for_composite(composite_prop.composite_class) if not converter: try: raise Exception( "Don't know how to convert the composite field %s (%s)" - % (composite, composite.composite_class) + % (composite_prop, composite_prop.composite_class) ) except AttributeError: # handle fields that are not attached to a class yet (don't have a parent) raise Exception( "Don't know how to convert the composite field %r (%s)" - % (composite, composite.composite_class) + % (composite_prop, composite_prop.composite_class) ) - return converter(composite, registry) + + # TODO Add a way to override composite fields default parameters + return converter(composite_prop, registry) def _register_composite_class(cls, registry=None): @@ -78,116 +309,402 @@ def inner(fn): convert_sqlalchemy_composite.register = _register_composite_class -def convert_sqlalchemy_column(column, registry=None): - return convert_sqlalchemy_type(getattr(column, "type", None), column, registry) +def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): + column = column_prop.columns[0] + # The converter expects a type to find the right conversion function. + # If we get an instance instead, we need to convert it to a type. + # The conversion function will still be able to access the instance via the column argument. + if "type_" not in field_kwargs: + column_type = getattr(column, "type", None) + if not isinstance(column_type, type): + column_type = type(column_type) + field_kwargs.setdefault( + "type_", + convert_sqlalchemy_type(column_type, column=column, registry=registry), + ) + field_kwargs.setdefault("required", not is_column_nullable(column)) + field_kwargs.setdefault("description", get_column_doc(column)) + return graphene.Field(resolver=resolver, **field_kwargs) -@singledispatch -def convert_sqlalchemy_type(type, column, registry=None): - raise Exception( - "Don't know how to convert the SQLAlchemy field %s (%s)" - % (column, column.__class__) - ) +@singledispatchbymatchfunction +def convert_sqlalchemy_type( # noqa + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, + **kwargs, +): + if replace_type_vars and type_arg in replace_type_vars: + return replace_type_vars[type_arg] -@convert_sqlalchemy_type.register(types.Date) -@convert_sqlalchemy_type.register(types.Time) -@convert_sqlalchemy_type.register(types.String) -@convert_sqlalchemy_type.register(types.Text) -@convert_sqlalchemy_type.register(types.Unicode) -@convert_sqlalchemy_type.register(types.UnicodeText) -@convert_sqlalchemy_type.register(postgresql.UUID) -@convert_sqlalchemy_type.register(postgresql.INET) -@convert_sqlalchemy_type.register(postgresql.CIDR) -@convert_sqlalchemy_type.register(TSVectorType) -def convert_column_to_string(type, column, registry=None): - return String( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + # No valid type found, raise an error + raise TypeError( + "Don't know how to convert the SQLAlchemy field %s (%s, %s). " + "Please add a type converter or set the type manually using ORMField(type_=your_type)" + % (column, column.__class__ or "no column provided", type_arg) + ) -@convert_sqlalchemy_type.register(types.DateTime) -def convert_column_to_datetime(type, column, registry=None): - from graphene.types.datetime import DateTime - return DateTime( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) +@convert_sqlalchemy_type.register(safe_isinstance(DeclarativeMeta)) +def convert_sqlalchemy_model_using_registry( + type_arg: Any, registry: Registry = None, **kwargs +): + registry_ = registry or get_global_registry() + def get_type_from_registry(): + existing_graphql_type = registry_.get_type_for_model(type_arg) + if existing_graphql_type: + return existing_graphql_type -@convert_sqlalchemy_type.register(types.SmallInteger) -@convert_sqlalchemy_type.register(types.Integer) -def convert_column_to_int_or_id(type, column, registry=None): - if column.primary_key: - return ID( - description=get_column_doc(column), - required=not (is_column_nullable(column)), + raise TypeError( + "No model found in Registry for type %s. " + "Only references to SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed." % type_arg ) + + return get_type_from_registry() + + +@convert_sqlalchemy_type.register(safe_issubclass(graphene.ObjectType)) +def convert_object_type(type_arg: Any, **kwargs): + return type_arg + + +@convert_sqlalchemy_type.register(safe_issubclass(graphene.Scalar)) +def convert_scalar_type(type_arg: Any, **kwargs): + return type_arg + + +@convert_sqlalchemy_type.register(safe_isinstance(TypeVar)) +def convert_type_var(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs): + return replace_type_vars[type_arg] + + +@convert_sqlalchemy_type.register(column_type_eq(str)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.String)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Unicode)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.UnicodeText)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.INET)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.CIDR)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.TSVectorType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.EmailType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.URLType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.IPAddressType)) +def convert_column_to_string(type_arg: Any, **kwargs): + return graphene.String + + +@convert_sqlalchemy_type.register(column_type_eq(postgresql.UUID)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.UUIDType)) +@convert_sqlalchemy_type.register(column_type_eq(uuid.UUID)) +def convert_column_to_uuid( + type_arg: Any, + **kwargs, +): + return graphene.UUID + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.DateTime)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.datetime)) +def convert_column_to_datetime( + type_arg: Any, + **kwargs, +): + return graphene.DateTime + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Time)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.time)) +def convert_column_to_time( + type_arg: Any, + **kwargs, +): + return graphene.Time + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Date)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.date)) +def convert_column_to_date( + type_arg: Any, + **kwargs, +): + return graphene.Date + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.SmallInteger)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Integer)) +@convert_sqlalchemy_type.register(column_type_eq(int)) +def convert_column_to_int_or_id( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + # fixme drop the primary key processing from here in another pr + if column is not None: + if getattr(column, "primary_key", False) is True: + return graphene.ID + return graphene.Int + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Boolean)) +@convert_sqlalchemy_type.register(column_type_eq(bool)) +def convert_column_to_boolean( + type_arg: Any, + **kwargs, +): + return graphene.Boolean + + +@convert_sqlalchemy_type.register(column_type_eq(float)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Float)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Numeric)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.BigInteger)) +def convert_column_to_float( + type_arg: Any, + **kwargs, +): + return graphene.Float + + +@convert_sqlalchemy_type.register(column_type_eq(postgresql.ENUM)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Enum)) +def convert_enum_to_enum( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("SQL-Enum conversion requires a column") + + return lambda: enum_for_sa_enum(column.type, registry or get_global_registry()) + + +# TODO Make ChoiceType conversion consistent with other enums +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ChoiceType)) +def convert_choice_to_enum( + type_arg: sqa_utils.ChoiceType, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("ChoiceType conversion requires a column") + + name = "{}_{}".format(column.table.name, column.key).upper() + if isinstance(column.type.type_impl, EnumTypeImpl): + # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta + # do not use from_enum here because we can have more than one enum column in table + return graphene.Enum(name, list((v.name, v.value) for v in column.type.choices)) else: - return Int( - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) + return graphene.Enum(name, column.type.choices) -@convert_sqlalchemy_type.register(types.Boolean) -def convert_column_to_boolean(type, column, registry=None): - return Boolean( - description=get_column_doc(column), required=not (is_column_nullable(column)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ScalarListType)) +def convert_scalar_list_to_list( + type_arg: Any, + **kwargs, +): + return graphene.List(graphene.String) + + +def init_array_list_recursive(inner_type, n): + return ( + inner_type + if n == 0 + else graphene.List(init_array_list_recursive(inner_type, n - 1)) ) -@convert_sqlalchemy_type.register(types.Float) -@convert_sqlalchemy_type.register(types.Numeric) -@convert_sqlalchemy_type.register(types.BigInteger) -def convert_column_to_float(type, column, registry=None): - return Float( - description=get_column_doc(column), required=not (is_column_nullable(column)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.ARRAY)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.ARRAY)) +def convert_array_to_list( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("SQL-Array conversion requires a column") + item_type = column.type.item_type + if not isinstance(item_type, type): + item_type = type(item_type) + inner_type = convert_sqlalchemy_type( + item_type, column=column, registry=registry, **kwargs + ) + return graphene.List( + init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1) ) -@convert_sqlalchemy_type.register(types.Enum) -def convert_enum_to_enum(type, column, registry=None): - return Field( - lambda: enum_for_sa_enum(type, registry or get_global_registry()), - description=get_column_doc(column), - required=not (is_column_nullable(column)), +@convert_sqlalchemy_type.register(column_type_eq(postgresql.HSTORE)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.JSON)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.JSONB)) +def convert_json_to_string( + type_arg: Any, + **kwargs, +): + return JSONString + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.JSONType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.JSON)) +def convert_json_type_to_string( + type_arg: Any, + **kwargs, +): + return JSONString + + +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Variant)) +def convert_variant_to_impl_type( + type_arg: sqa_types.Variant, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("Vaiant conversion requires a column") + + type_impl = column.type.impl + if not isinstance(type_impl, type): + type_impl = type(type_impl) + return convert_sqlalchemy_type( + type_impl, column=column, registry=registry, **kwargs ) -@convert_sqlalchemy_type.register(ChoiceType) -def convert_choice_to_enum(type, column, registry=None): - name = "{}_{}".format(column.table.name, column.name).upper() - return Enum(name, type.choices, description=get_column_doc(column)) +@convert_sqlalchemy_type.register(column_type_eq(Decimal)) +def convert_sqlalchemy_hybrid_property_type_decimal(type_arg: Any, **kwargs): + # The reason Decimal should be serialized as a String is because this is a + # base10 type used in things like money, and string allows it to not + # lose precision (which would happen if we downcasted to a Float, for example) + return graphene.String -@convert_sqlalchemy_type.register(ScalarListType) -def convert_scalar_list_to_list(type, column, registry=None): - return List(String, description=get_column_doc(column)) +def is_union(type_arg: Any, **kwargs) -> bool: + if sys.version_info >= (3, 10): + from types import UnionType + if isinstance(type_arg, UnionType): + return True + return getattr(type_arg, "__origin__", None) == typing.Union -@convert_sqlalchemy_type.register(postgresql.ARRAY) -def convert_postgres_array_to_list(_type, column, registry=None): - graphene_type = convert_sqlalchemy_type(column.type.item_type, column) - inner_type = type(graphene_type) - return List( - inner_type, - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) +def graphene_union_for_py_union( + obj_types: typing.List[graphene.ObjectType], registry +) -> graphene.Union: + union_type = registry.get_union_for_object_types(obj_types) -@convert_sqlalchemy_type.register(postgresql.HSTORE) -@convert_sqlalchemy_type.register(postgresql.JSON) -@convert_sqlalchemy_type.register(postgresql.JSONB) -def convert_json_to_string(type, column, registry=None): - return JSONString( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + if union_type is None: + # Union Name is name of the three + union_name = "".join(sorted(obj_type._meta.name for obj_type in obj_types)) + union_type = graphene.Union.create_type(union_name, types=obj_types) + registry.register_union_type(union_type, obj_types) + + return union_type + + +@convert_sqlalchemy_type.register(is_union) +def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): + """ + Converts Unions (Union[X,Y], or X | Y for python > 3.10) to the corresponding graphene schema object. + Since Optionals are internally represented as Union[T, ], they are handled here as well. + The GQL Spec currently only allows for ObjectType unions: + GraphQL Unions represent an object that could be one of a list of GraphQL Object types, but provides for no + guaranteed fields between those types. + That's why we have to check for the nested types to be instances of graphene.ObjectType, except for the union case. -@convert_sqlalchemy_type.register(JSONType) -def convert_json_type_to_string(type, column, registry=None): - return JSONString( - description=get_column_doc(column), required=not (is_column_nullable(column)) + type(x) == _types.UnionType is necessary to support X | Y notation, but might break in future python releases. + """ + from .registry import get_global_registry + + # Option is actually Union[T, ] + # Just get the T out of the list of arguments by filtering out the NoneType + nested_types = list(filter(lambda x: not type(None) == x, type_arg.__args__)) + + # TODO redo this for , *args, **kwargs + # Map the graphene types to the nested types. + # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... + graphene_types = list(map(convert_sqlalchemy_type, nested_types)) + + # If only one type is left after filtering out NoneType, the Union was an Optional + if len(graphene_types) == 1: + return graphene_types[0] + + # Now check if every type is instance of an ObjectType + if not all( + isinstance(graphene_type, type(graphene.ObjectType)) + for graphene_type in graphene_types + ): + raise ValueError( + "Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " + "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " + "or use an ORMField to override this behaviour." + ) + + return graphene_union_for_py_union( + cast(typing.List[graphene.ObjectType], list(graphene_types)), + get_global_registry(), ) + + +@convert_sqlalchemy_type.register( + lambda x: getattr(x, "__origin__", None) in [list, typing.List] +) +def convert_sqlalchemy_hybrid_property_type_list_t(type_arg: Any, **kwargs): + # type is either list[T] or List[T], generic argument at __args__[0] + internal_type = type_arg.__args__[0] + + graphql_internal_type = convert_sqlalchemy_type(internal_type, **kwargs) + + return graphene.List(graphql_internal_type) + + +@convert_sqlalchemy_type.register(safe_isinstance(ForwardRef)) +def convert_sqlalchemy_hybrid_property_forwardref(type_arg: Any, **kwargs): + """ + Generate a lambda that will resolve the type at runtime + This takes care of self-references + """ + from .registry import get_global_registry + + def forward_reference_solver(): + model = registry_sqlalchemy_model_from_str(type_arg.__forward_arg__) + if not model: + raise TypeError( + "No model found in Registry for forward reference for type %s. " + "Only forward references to other SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed." % type_arg + ) + # Always fall back to string if no ForwardRef type found. + return get_global_registry().get_type_for_model(model) + + return forward_reference_solver + + +@convert_sqlalchemy_type.register(safe_isinstance(str)) +def convert_sqlalchemy_hybrid_property_bare_str(type_arg: str, **kwargs): + """ + Convert Bare String into a ForwardRef + """ + + return convert_sqlalchemy_type(ForwardRef(type_arg), **kwargs) + + +def convert_hybrid_property_return_type(hybrid_prop): + # Grab the original method's return type annotations from inside the hybrid property + return_type_annotation = hybrid_prop.fget.__annotations__.get("return", None) + if not return_type_annotation: + raise TypeError( + "Cannot convert hybrid property type {} to a valid graphene type. " + "Please make sure to annotate the return type of the hybrid property or use the " + "type_ attribute of ORMField to set the type.".format(hybrid_prop) + ) + + return convert_sqlalchemy_type(return_type_annotation, column=hybrid_prop) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 6b84bf52..97f8997c 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column +from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType from graphene import Argument, Enum, List @@ -18,9 +18,7 @@ def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): The Enum value names are converted to upper case if necessary. """ if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)) enum_class = sa_enum.enum_class if enum_class: if all(to_enum_value_name(key) == key for key in enum_class.__members__): @@ -45,9 +43,7 @@ def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): def enum_for_sa_enum(sa_enum, registry): """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)) enum = registry.get_graphene_enum_for_sa_enum(sa_enum) if not enum: enum = _convert_sa_to_graphene_enum(sa_enum) @@ -60,20 +56,19 @@ def enum_for_field(obj_type, field_name): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + raise TypeError("Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) if not field_name or not isinstance(field_name, str): - raise TypeError( - "Expected a field name, but got: {!r}".format(field_name)) + raise TypeError("Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) if orm_field is None: raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name)) - if not isinstance(orm_field, Column): + if not isinstance(orm_field, ColumnProperty): raise TypeError( "{}.{} does not map to model column".format(obj_type._meta.name, field_name) ) - sa_enum = orm_field.type + column = orm_field.columns[0] + sa_enum = column.type if not isinstance(sa_enum, SQLAlchemyEnumType): raise TypeError( "{}.{} does not map to enum column".format(obj_type._meta.name, field_name) @@ -138,15 +133,16 @@ def sort_enum_for_object_type( if only_fields and field_name not in only_fields: continue orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) - if not isinstance(orm_field, Column): + if not isinstance(orm_field, ColumnProperty): continue - if only_indexed and not (orm_field.primary_key or orm_field.index): + column = orm_field.columns[0] + if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(orm_field.name, True) - asc_value = EnumValue(asc_name, orm_field.asc()) - desc_name = get_name(orm_field.name, False) - desc_value = EnumValue(desc_name, orm_field.desc()) - if orm_field.primary_key: + asc_name = get_name(field_name, True) + asc_value = EnumValue(asc_name, column.asc()) + desc_name = get_name(field_name, False) + desc_value = EnumValue(desc_name, column.desc()) + if column.primary_key: default.append(asc_value) members.extend(((asc_name, asc_value), (desc_name, desc_value))) enum = Enum(name, members) @@ -164,7 +160,7 @@ def sort_argument_for_object_type( get_symbol_name=None, has_default=True, ): - """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + """ "Returns Graphene Argument for sorting the given SQLAlchemyObjectType. Parameters - obj_type : SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 3ad15a92..ef798852 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,65 +1,175 @@ -import logging +import enum +import warnings from functools import partial from promise import Promise, is_thenable from sqlalchemy.orm.query import Query from graphene.relay import Connection, ConnectionField -from graphene.relay.connection import PageInfo -from graphql_relay.connection.arrayconnection import connection_from_list_slice +from graphene.relay.connection import connection_adapter, page_info_adapter +from graphql_relay import connection_from_array_slice -from .utils import get_query +from .batching import get_batch_resolver +from .filters import BaseTypeFilter +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + EnumValue, + get_nullable_type, + get_query, + get_session, +) -log = logging.getLogger() +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession -class UnsortedSQLAlchemyConnectionField(ConnectionField): +class SQLAlchemyConnectionField(ConnectionField): @property def type(self): from .types import SQLAlchemyObjectType - _type = super(ConnectionField, self).type - if issubclass(_type, Connection): - return _type - assert issubclass(_type, SQLAlchemyObjectType), ( + type_ = super(ConnectionField, self).type + nullable_type = get_nullable_type(type_) + if issubclass(nullable_type, Connection): + return type_ + assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" - ).format(_type.__name__) - assert _type._meta.connection, "The type {} doesn't have a connection".format( - _type.__name__ + ).format(nullable_type.__name__) + assert nullable_type.connection, "The type {} doesn't have a connection".format( + nullable_type.__name__ ) - return _type._meta.connection + assert type_ == nullable_type, ( + "Passing a SQLAlchemyObjectType instance is deprecated. " + "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" + ) + return nullable_type.connection + + def __init__(self, type_, *args, **kwargs): + nullable_type = get_nullable_type(type_) + # Handle Sorting and Filtering + if ( + "sort" not in kwargs + and nullable_type + and issubclass(nullable_type, Connection) + ): + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + nullable_type.__name__ + ) + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + + if ( + "filter" not in kwargs + and nullable_type + and issubclass(nullable_type, Connection) + ): + # Only add filtering if a filter argument exists on the object type + filter_argument = nullable_type.Edge.node._type.get_filter_argument() + if filter_argument: + kwargs.setdefault("filter", filter_argument) + elif "filter" in kwargs and kwargs["filter"] is None: + del kwargs["filter"] + + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) @property def model(self): - return self.type._meta.node._meta.model + return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): + def get_query(cls, model, info, sort=None, filter=None, **args): query = get_query(model, info.context) if sort is not None: - if isinstance(sort, str): - query = query.order_by(sort.value) - else: - query = query.order_by(*(col.value for col in sort)) + if not isinstance(sort, list): + sort = [sort] + sort_args = [] + # ensure consistent handling of graphene Enums, enum values and + # plain strings + for item in sort: + if isinstance(item, enum.Enum): + sort_args.append(item.value.value) + elif isinstance(item, EnumValue): + sort_args.append(item.value) + else: + sort_args.append(item) + query = query.order_by(*sort_args) + + if filter is not None: + assert isinstance(filter, dict) + filter_type: BaseTypeFilter = type(filter) + query, clauses = filter_type.execute_filters(query, filter) + query = query.filter(*clauses) return query @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): + session = get_session(info.context) if resolved is None: - resolved = cls.get_query(model, info, **args) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return await cls.resolve_connection_async( + connection_type, model, info, args, resolved + ) + + return get_result() + + else: + resolved = cls.get_query(model, info, **args) if isinstance(resolved, Query): _len = resolved.count() else: _len = len(resolved) - connection = connection_from_list_slice( - resolved, - args, + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, slice_start=0, - list_length=_len, - list_slice_length=_len, - connection_type=connection_type, - pageinfo_type=PageInfo, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, edge_type=connection_type.Edge, + page_info_type=page_info_adapter, + ) + connection.iterable = resolved + connection.length = _len + return connection + + @classmethod + async def resolve_connection_async( + cls, connection_type, model, info, args, resolved + ): + session = get_session(info.context) + if resolved is None: + query = cls.get_query(model, info, **args) + resolved = (await session.scalars(query)).all() + if isinstance(resolved, Query): + _len = resolved.count() + else: + _len = len(resolved) + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, + slice_start=0, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, + edge_type=connection_type.Edge, + page_info_type=page_info_adapter, ) connection.iterable = resolved connection.length = _len @@ -75,59 +185,108 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg return on_resolve(resolved) - def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, parent_resolver, self.type, self.model) + def wrap_resolve(self, parent_resolver): + return partial( + self.connection_resolver, + parent_resolver, + get_nullable_type(self.type), + self.model, + ) -class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): - def __init__(self, type, *args, **kwargs): - if "sort" not in kwargs and issubclass(type, Connection): - # Let super class raise if type is not a Connection - try: - kwargs.setdefault("sort", type.Edge.node._type.sort_argument()) - except (AttributeError, TypeError): - raise TypeError( - 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' - " to None to disabling the creation of the sort query argument".format( - type.__name__ - ) - ) - elif "sort" in kwargs and kwargs["sort"] is None: - del kwargs["sort"] - super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) +# TODO Remove in next major version +class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField): + def __init__(self, type_, *args, **kwargs): + if "sort" in kwargs and kwargs["sort"] is not None: + warnings.warn( + "UnsortedSQLAlchemyConnectionField does not support sorting. " + "All sorting arguments will be ignored." + ) + kwargs["sort"] = None + warnings.warn( + "UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyConnectionField instead and either don't " + "provide the `sort` argument or set it to None if you do not want sorting.", + DeprecationWarning, + ) + super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) + + +class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): + """ + This is currently experimental. + The API and behavior may change in future versions. + Use at your own risk. + """ + + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + if root is None: + resolved = resolver(root, info, **args) + on_resolve = partial( + cls.resolve_connection, connection_type, model, info, args + ) + else: + relationship_prop = None + for relationship in root.__class__.__mapper__.relationships: + if relationship.mapper.class_ == model: + relationship_prop = relationship + break + resolved = get_batch_resolver(relationship_prop)(root, info, **args) + on_resolve = partial( + cls.resolve_connection, connection_type, root, info, args + ) + + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) + + @classmethod + def from_relationship(cls, relationship, registry, **field_kwargs): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return cls( + model_type.connection, + resolver=get_batch_resolver(relationship), + **field_kwargs, + ) -def default_connection_field_factory(relationship, registry): +def default_connection_field_factory(relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return createConnectionField(model_type) + return __connectionFactory(model_type, **field_kwargs) # TODO Remove in next major version __connectionFactory = UnsortedSQLAlchemyConnectionField -def createConnectionField(_type): - log.warning( - 'createConnectionField is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' +def createConnectionField(type_, **field_kwargs): + warnings.warn( + "createConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", + DeprecationWarning, ) - return __connectionFactory(_type) + return __connectionFactory(type_, **field_kwargs) def registerConnectionFieldFactory(factoryMethod): - log.warning( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + warnings.warn( + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", + DeprecationWarning, ) global __connectionFactory __connectionFactory = factoryMethod def unregisterConnectionFieldFactory(): - log.warning( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + warnings.warn( + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", + DeprecationWarning, ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py new file mode 100644 index 00000000..cbe3d09d --- /dev/null +++ b/graphene_sqlalchemy/filters.py @@ -0,0 +1,532 @@ +import re +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union + +from graphql import Undefined +from sqlalchemy import and_, not_, or_ +from sqlalchemy.orm import Query, aliased # , selectinload + +import graphene +from graphene.types.inputobjecttype import ( + InputObjectTypeContainer, + InputObjectTypeOptions, +) +from graphene_sqlalchemy.utils import is_list + +BaseTypeFilterSelf = TypeVar( + "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer +) + + +class SQLAlchemyFilterInputField(graphene.InputField): + def __init__( + self, + type_, + model_attr, + name=None, + default_value=Undefined, + deprecation_reason=None, + description=None, + required=False, + _creation_counter=None, + **extra_args, + ): + super(SQLAlchemyFilterInputField, self).__init__( + type_, + name, + default_value, + deprecation_reason, + description, + required, + _creation_counter, + **extra_args, + ) + + self.model_attr = model_attr + + +def _get_functions_by_regex( + regex: str, subtract_regex: str, class_: Type +) -> List[Tuple[str, Dict[str, Any]]]: + function_regex = re.compile(regex) + + matching_functions = [] + + # Search the entire class for functions matching the filter regex + for fn in dir(class_): + func_attr = getattr(class_, fn) + # Check if attribute is a function + if callable(func_attr) and function_regex.match(fn): + # add function and attribute name to the list + matching_functions.append( + (re.sub(subtract_regex, "", fn), func_attr.__annotations__) + ) + return matching_functions + + +class BaseTypeFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, filter_fields=None, model=None, _meta=None, **options + ): + from graphene_sqlalchemy.converter import convert_sqlalchemy_type + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls) + + new_filter_fields = {} + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in logic_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + + replace_type_vars = {BaseTypeFilterSelf: cls} + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + + if _meta.fields: + _meta.fields.update(filter_fields) + else: + _meta.fields = filter_fields + _meta.fields.update(new_filter_fields) + + _meta.model = model + + super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + @classmethod + def and_logic( + cls, + query, + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], + model_alias=None, + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + clauses = [] + for value in val: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, value, model_alias=model_alias + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [and_(*clauses)] + + @classmethod + def or_logic( + cls, + query, + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], + model_alias=None, + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + + clauses = [] + for value in val: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, value, model_alias=model_alias + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [or_(*clauses)] + + @classmethod + def execute_filters( + cls, query, filter_dict: Dict[str, Any], model_alias=None + ) -> Tuple[Query, List[Any]]: + model = cls._meta.model + if model_alias: + model = model_alias + + clauses = [] + + for field, field_filters in filter_dict.items(): + # Relationships are Dynamic, we need to resolve them fist + # Maybe we can cache these dynamics to improve efficiency + # Check with a profiler is required to determine necessity + input_field = cls._meta.fields[field] + if isinstance(input_field, graphene.Dynamic): + input_field = input_field.get_type() + field_filter_type = input_field.type + else: + field_filter_type = cls._meta.fields[field].type + # raise Exception + # TODO we need to save the relationship props in the meta fields array + # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) + if field == "and": + query, _clauses = cls.and_logic( + query, field_filter_type.of_type, field_filters, model_alias=model + ) + clauses.extend(_clauses) + elif field == "or": + query, _clauses = cls.or_logic( + query, field_filter_type.of_type, field_filters, model_alias=model + ) + clauses.extend(_clauses) + else: + # Get the model attr from the inputfield in case the field is aliased in graphql + model_field = getattr(model, input_field.model_attr or field) + if issubclass(field_filter_type, BaseTypeFilter): + # Get the model to join on the Filter Query + joined_model = field_filter_type._meta.model + # Always alias the model + joined_model_alias = aliased(joined_model) + # Join the aliased model onto the query + query = query.join(model_field.of_type(joined_model_alias)) + # Pass the joined query down to the next object type filter for processing + query, _clauses = field_filter_type.execute_filters( + query, field_filters, model_alias=joined_model_alias + ) + clauses.extend(_clauses) + if issubclass(field_filter_type, RelationshipFilter): + # TODO see above; not yet working + relationship_prop = field_filter_type._meta.model + # Always alias the model + # joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + # todo should we use selectinload here instead of join for large lists? + + query, _clauses = field_filter_type.execute_filters( + query, model, model_field, field_filters, relationship_prop + ) + clauses.extend(_clauses) + elif issubclass(field_filter_type, FieldFilter): + query, _clauses = field_filter_type.execute_filters( + query, model_field, field_filters + ) + clauses.extend(_clauses) + + return query, clauses + + +ScalarFilterInputType = TypeVar("ScalarFilterInputType") + + +class FieldFilterOptions(InputObjectTypeOptions): + graphene_type: Type = None + + +class FieldFilter(graphene.InputObjectType): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + @classmethod + def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): + from .converter import convert_sqlalchemy_type + + # get all filter functions + + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = FieldFilterOptions(cls) + + if not _meta.graphene_type: + _meta.graphene_type = graphene_type + + new_filter_fields = {} + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) + + # Add all fields to the meta options. graphene.InputbjectType will take care of the rest + if _meta.fields: + _meta.fields.update(new_filter_fields) + else: + _meta.fields = new_filter_fields + + # Pass modified meta to the super class + super(FieldFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + @classmethod + def in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.in_(val) + + @classmethod + def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.notin_(val) + + # TODO add like/ilike + + @classmethod + def execute_filters( + cls, query, field, filter_dict: Dict[str, any] + ) -> Tuple[Query, List[Any]]: + clauses = [] + for filt, val in filter_dict.items(): + clause = getattr(cls, filt + "_filter")(query, field, val) + if isinstance(clause, tuple): + query, clause = clause + clauses.append(clause) + + return query, clauses + + +class SQLEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val.value + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val.value) + + +class PyEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + +class StringFilter(FieldFilter): + class Meta: + graphene_type = graphene.String + + @classmethod + def like_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.like(val) + + @classmethod + def ilike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.ilike(val) + + @classmethod + def notlike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.notlike(val) + + +class BooleanFilter(FieldFilter): + class Meta: + graphene_type = graphene.Boolean + + +class OrderedFilter(FieldFilter): + class Meta: + abstract = True + + @classmethod + def gt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field > val + + @classmethod + def gte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field >= val + + @classmethod + def lt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field < val + + @classmethod + def lte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field <= val + + +class NumberFilter(OrderedFilter): + """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" + + class Meta: + abstract = True + + +class FloatFilter(NumberFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.Float + + +class IntFilter(NumberFilter): + class Meta: + graphene_type = graphene.Int + + +class DateFilter(OrderedFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.Date + + +class DateTimeFilter(OrderedFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.DateTime + + +class IdFilter(FieldFilter): + class Meta: + graphene_type = graphene.ID + + +class RelationshipFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, base_type_filter=None, model=None, _meta=None, **options + ): + if not base_type_filter: + raise Exception("Relationship Filters must be specific to an object type") + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + # get all filter functions + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) + + relationship_filters = {} + + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + if is_list(_annotations["val"]): + relationship_filters.update( + {field_name: graphene.InputField(graphene.List(base_type_filter))} + ) + else: + relationship_filters.update( + {field_name: graphene.InputField(base_type_filter)} + ) + + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(relationship_filters) + else: + _meta.fields = relationship_filters + + _meta.model = model + _meta.base_type_filter = base_type_filter + super(RelationshipFilter, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + @classmethod + def contains_filter( + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], + ): + clauses = [] + for v in val: + # Always alias the model + joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + query = query.join(field.of_type(joined_model_alias)).distinct() + # pass the alias so group can join group + query, _clauses = cls._meta.base_type_filter.execute_filters( + query, v, model_alias=joined_model_alias + ) + clauses.append(and_(*_clauses)) + return query, [or_(*clauses)] + + @classmethod + def contains_exactly_filter( + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], + ): + raise NotImplementedError + + @classmethod + def execute_filters( + cls: Type[FieldFilter], + query, + parent_model, + field, + filter_dict: Dict, + relationship_prop, + ) -> Tuple[Query, List[Any]]: + query, clauses = (query, []) + + for filt, val in filter_dict.items(): + query, _clauses = getattr(cls, filt + "_filter")( + query, parent_model, field, relationship_prop, val + ) + clauses += _clauses + + return query, clauses diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index acfa744b..b959d221 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,8 +1,15 @@ +import inspect from collections import defaultdict +from typing import TYPE_CHECKING, List, Type from sqlalchemy.types import Enum as SQLAlchemyEnumType +import graphene from graphene import Enum +from graphene.types.base import BaseType + +if TYPE_CHECKING: # pragma: no_cover + from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter class Registry(object): @@ -13,16 +20,37 @@ def __init__(self): self._registry_composites = {} self._registry_enums = {} self._registry_sort_enums = {} + self._registry_unions = {} + self._registry_scalar_filters = {} + self._registry_base_type_filters = {} + self._registry_relationship_filters = {} - def register(self, obj_type): - from .types import SQLAlchemyObjectType + self._init_base_filters() - if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType - ): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + def _init_base_filters(self): + import graphene_sqlalchemy.filters as gsqa_filters + + from .filters import FieldFilter + + field_filter_classes = [ + filter_cls[1] + for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass) + if ( + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) ) + ] + for field_filter_class in field_filter_classes: + self.register_filter_for_scalar_type( + field_filter_class._meta.graphene_type, field_filter_class + ) + + def register(self, obj_type): + from .types import SQLAlchemyBase + + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase): + raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type)) assert obj_type._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( # 'SQLAlchemy model "{}" already associated with ' @@ -34,14 +62,10 @@ def get_type_for_model(self, model): return self._registry.get(model) def register_orm_field(self, obj_type, field_name, orm_field): - from .types import SQLAlchemyObjectType + from .types import SQLAlchemyBase - if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType - ): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) - ) + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase): + raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type)) if not field_name or not isinstance(field_name, str): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) self._registry_orm_fields[obj_type][field_name] = orm_field @@ -55,7 +79,7 @@ def register_composite_converter(self, composite, converter): def get_converter_for_composite(self, composite): return self._registry_composites.get(composite) - def register_enum(self, sa_enum, graphene_enum): + def register_enum(self, sa_enum: SQLAlchemyEnumType, graphene_enum: Enum): if not isinstance(sa_enum, SQLAlchemyEnumType): raise TypeError( "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) @@ -67,10 +91,11 @@ def register_enum(self, sa_enum, graphene_enum): self._registry_enums[sa_enum] = graphene_enum - def get_graphene_enum_for_sa_enum(self, sa_enum): + def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType): return self._registry_enums.get(sa_enum) - def register_sort_enum(self, obj_type, sort_enum): + def register_sort_enum(self, obj_type, sort_enum: Enum): + from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( @@ -83,9 +108,130 @@ def register_sort_enum(self, obj_type, sort_enum): raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) self._registry_sort_enums[obj_type] = sort_enum - def get_sort_enum_for_object_type(self, obj_type): + def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) + def register_union_type( + self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] + ): + if not issubclass(union, graphene.Union): + raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) + + for obj_type in obj_types: + if not issubclass(obj_type, graphene.ObjectType): + raise TypeError( + "Expected Graphene ObjectType, but got: {!r}".format(obj_type) + ) + + self._registry_unions[frozenset(obj_types)] = union + + def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]): + return self._registry_unions.get(frozenset(obj_types)) + + # Filter Scalar Fields of Object Types + def register_filter_for_scalar_type( + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] + ): + from .filters import FieldFilter + + if not isinstance(scalar_type, type(graphene.Scalar)): + raise TypeError("Expected Scalar, but got: {!r}".format(scalar_type)) + + if not issubclass(filter_obj, FieldFilter): + raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj)) + self._registry_scalar_filters[scalar_type] = filter_obj + + def get_filter_for_sql_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import SQLEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = SQLEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_py_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import PyEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = PyEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_scalar_type( + self, scalar_type: Type[graphene.Scalar] + ) -> Type["FieldFilter"]: + from .filters import FieldFilter + + filter_type = self._registry_scalar_filters.get(scalar_type) + if not filter_type: + filter_type = FieldFilter.create_type( + f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type + ) + self._registry_scalar_filters[scalar_type] = filter_type + + return filter_type + + # TODO register enums automatically + def register_filter_for_enum_type( + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] + ): + from .filters import FieldFilter + + if not issubclass(enum_type, graphene.Enum): + raise TypeError("Expected Enum, but got: {!r}".format(enum_type)) + + if not issubclass(filter_obj, FieldFilter): + raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj)) + self._registry_scalar_filters[enum_type] = filter_obj + + # Filter Base Types + def register_filter_for_base_type( + self, + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], + ): + from .filters import BaseTypeFilter + + if not issubclass(base_type, BaseType): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) + + if not issubclass(filter_obj, BaseTypeFilter): + raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj)) + self._registry_base_type_filters[base_type] = filter_obj + + def get_filter_for_base_type(self, base_type: Type[BaseType]): + return self._registry_base_type_filters.get(base_type) + + # Filter Relationships between base types + def register_relationship_filter_for_base_type( + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] + ): + from .filters import RelationshipFilter + + if not isinstance(base_type, type(BaseType)): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) + + if not issubclass(filter_obj, RelationshipFilter): + raise TypeError( + "Expected RelationshipFilter, but got: {!r}".format(filter_obj) + ) + self._registry_relationship_filters[base_type] = filter_obj + + def get_relationship_filter_for_base_type( + self, base_type: Type[BaseType] + ) -> "RelationshipFilter": + return self._registry_relationship_filters.get(base_type) + registry = None diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py new file mode 100644 index 00000000..e8e61911 --- /dev/null +++ b/graphene_sqlalchemy/resolvers.py @@ -0,0 +1,26 @@ +from graphene.utils.get_unbound_function import get_unbound_function + + +def get_custom_resolver(obj_type, orm_field_name): + """ + Since `graphene` will call `resolve_` on a field only if it + does not have a `resolver`, we need to re-implement that logic here so + users are able to override the default resolvers that we provide. + """ + resolver = getattr(obj_type, "resolve_{}".format(orm_field_name), None) + if resolver: + return get_unbound_function(resolver) + + return None + + +def get_attr_resolver(obj_type, model_attr): + """ + In order to support field renaming via `ORMField.model_attr`, + we need to define resolver functions for each field. + + :param SQLAlchemyObjectType obj_type: + :param str model_attr: the name of the SQLAlchemy attribute + :rtype: Callable + """ + return lambda root, _info: getattr(root, model_attr, None) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 2825eb3c..2c749da7 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,32 +1,83 @@ import pytest +import pytest_asyncio from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import sessionmaker +from typing_extensions import Literal +import graphene +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 + +from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry -from .models import Base +from .models import Base, CompositeFullName -test_db_url = 'sqlite://' # use in-memory database for tests +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @pytest.fixture(autouse=True) def reset_registry(): reset_global_registry() + # Prevent tests that implicitly depend on Reporter from raising + # Tests that explicitly depend on this behavior should re-register a converter + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.Field(graphene.Int) + + +# make a typed literal for session one is sync and one is async +SESSION_TYPE = Literal["sync", "session_factory"] + + +@pytest.fixture(params=["sync", "async"]) +def session_type(request) -> SESSION_TYPE: + return request.param + + +@pytest.fixture +def async_session(session_type): + return session_type == "async" + + +@pytest.fixture +def test_db_url(session_type: SESSION_TYPE): + if session_type == "async": + return "sqlite+aiosqlite://" + else: + return "sqlite://" + + +@pytest.mark.asyncio +@pytest_asyncio.fixture(scope="function") +async def session_factory(session_type: SESSION_TYPE, test_db_url: str): + if session_type == "async": + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") + engine = create_async_engine(test_db_url) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) + await engine.dispose() + else: + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() -@pytest.yield_fixture(scope="function") -def session(): - db = create_engine(test_db_url) - connection = db.engine.connect() - transaction = connection.begin() - Base.metadata.create_all(connection) - # options = dict(bind=connection, binds={}) - session_factory = sessionmaker(bind=connection) - session = scoped_session(session_factory) +@pytest_asyncio.fixture(scope="function") +async def sync_session_factory(): + engine = create_engine("sqlite://") + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() - yield session - # Finalize test here - transaction.rollback() - connection.close() - session.remove() +@pytest_asyncio.fixture(scope="function") +def session(session_factory): + return session_factory() diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 12781cc5..e1ee9858 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -1,17 +1,48 @@ from __future__ import absolute_import +import datetime import enum +import uuid +from decimal import Decimal +from typing import List, Optional -from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table +# fmt: off +from sqlalchemy import ( + Column, + Date, + Enum, + ForeignKey, + Integer, + Numeric, + String, + Table, + func, +) +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import mapper, relationship +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import backref, column_property, composite, mapper, relationship +from sqlalchemy.sql.type_api import TypeEngine + +from graphene_sqlalchemy.tests.utils import wrap_select_func +from graphene_sqlalchemy.utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, +) + +# fmt: off +if SQL_VERSION_HIGHER_EQUAL_THAN_2: + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip +else: + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip +# fmt: on PetKind = Enum("cat", "dog", name="pet_kind") class HairKind(enum.Enum): - LONG = 'long' - SHORT = 'short' + LONG = "long" + SHORT = "short" Base = declarative_base() @@ -37,24 +68,112 @@ class Pet(Base): pet_kind = Column(PetKind, nullable=False) hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + legs = Column(Integer(), default=4) + + +class CompositeFullName(object): + def __init__(self, first_name, last_name): + self.first_name = first_name + self.last_name = last_name + + def __composite_values__(self): + return self.first_name, self.last_name + + def __repr__(self): + return "{} {}".format(self.first_name, self.last_name) + + +class ProxiedReporter(Base): + __tablename__ = "reporters_error" + id = Column(Integer(), primary_key=True) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + reporter = relationship("Reporter", uselist=False) + + # This is a hybrid property, we don't support proxies on hybrids yet + composite_prop = association_proxy("reporter", "composite_prop") class Reporter(Base): __tablename__ = "reporters" + id = Column(Integer(), primary_key=True) - first_name = Column(String(30)) - last_name = Column(String(30)) - email = Column(String()) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) - pets = relationship("Pet", secondary=association_table, backref="reporters") - articles = relationship("Article", backref="reporter") - favorite_article = relationship("Article", uselist=False) + pets = relationship( + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + lazy="selectin", + ) + articles = relationship( + "Article", backref=backref("reporter", lazy="selectin"), lazy="selectin" + ) + favorite_article = relationship("Article", uselist=False, lazy="selectin") + + @hybrid_property + def hybrid_prop_with_doc(self) -> str: + """Docstring test""" + return self.first_name + + @hybrid_property + def hybrid_prop(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_list(self) -> List[int]: + return [1, 2, 3] + + column_prop = column_property( + wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property" + ) + + composite_prop = composite( + CompositeFullName, first_name, last_name, doc="Composite" + ) + + headlines = association_proxy("articles", "headline") + + +articles_tags_table = Table( + "articles_tags", + Base.metadata, + Column("article_id", ForeignKey("articles.id")), + Column("tag_id", ForeignKey("tags.id")), +) - # total = column_property( - # select([ - # func.cast(func.count(PersonInfo.id), Float) - # ]) - # ) + +class Image(Base): + __tablename__ = "images" + id = Column(Integer(), primary_key=True) + external_id = Column(Integer()) + description = Column(String(30)) + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) class Article(Base): @@ -63,6 +182,32 @@ class Article(Base): headline = Column(String(100)) pub_date = Column(Date()) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + recommended_reads = association_proxy("reporter", "articles") + + # one-to-one relationship with image + image_id = Column(Integer(), ForeignKey("images.id"), unique=True) + image = relationship("Image", backref=backref("articles", uselist=False)) + + # many-to-many relationship with tags + tags = relationship("Tag", secondary=articles_tags_table, backref="articles") + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) class ReflectedEditor(type): @@ -75,4 +220,226 @@ def __subclasses__(cls): editor_table = Table("editors", Base.metadata, autoload=True) -mapper(ReflectedEditor, editor_table) +# TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4 +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + Base.registry.map_imperatively(ReflectedEditor, editor_table) +else: + mapper(ReflectedEditor, editor_table) + + +############################################ +# The models below are mainly used in the +# @hybrid_property type inference scenarios +############################################ + + +class ShoppingCartItem(Base): + __tablename__ = "shopping_cart_items" + + id = Column(Integer(), primary_key=True) + + @hybrid_property + def hybrid_prop_shopping_cart(self) -> List["ShoppingCart"]: + return [ShoppingCart(id=1)] + + +class ShoppingCart(Base): + __tablename__ = "shopping_carts" + + id = Column(Integer(), primary_key=True) + + # Standard Library types + + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_decimal(self) -> Decimal: + return Decimal("3.14") + + @hybrid_property + def hybrid_prop_date(self) -> datetime.date: + return datetime.datetime.now().date() + + @hybrid_property + def hybrid_prop_time(self) -> datetime.time: + return datetime.datetime.now().time() + + @hybrid_property + def hybrid_prop_datetime(self) -> datetime.datetime: + return datetime.datetime.now() + + # Lists and Nested Lists + + @hybrid_property + def hybrid_prop_list_int(self) -> List[int]: + return [1, 2, 3] + + @hybrid_property + def hybrid_prop_list_date(self) -> List[datetime.date]: + return [self.hybrid_prop_date, self.hybrid_prop_date, self.hybrid_prop_date] + + @hybrid_property + def hybrid_prop_nested_list_int(self) -> List[List[int]]: + return [ + self.hybrid_prop_list_int, + ] + + @hybrid_property + def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: + return [ + [ + self.hybrid_prop_list_int, + ], + ] + + # Other SQLAlchemy Instance + @hybrid_property + def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + # Other SQLAlchemy Instance with expression + @hybrid_property + def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + @hybrid_prop_first_shopping_cart_item_expression.expression + def hybrid_prop_first_shopping_cart_item_expression(cls): + return ShoppingCartItem + + # Other SQLAlchemy Instances + @hybrid_property + def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: + return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] + + # Self-references + + @hybrid_property + def hybrid_prop_self_referential(self) -> "ShoppingCart": + return ShoppingCart(id=1) + + @hybrid_property + def hybrid_prop_self_referential_list(self) -> List["ShoppingCart"]: + return [ShoppingCart(id=1)] + + # Optional[T] + + @hybrid_property + def hybrid_prop_optional_self_referential(self) -> Optional["ShoppingCart"]: + return None + + # UUIDS + @hybrid_property + def hybrid_prop_uuid(self) -> uuid.UUID: + return uuid.uuid4() + + @hybrid_property + def hybrid_prop_uuid_list(self) -> List[uuid.UUID]: + return [ + uuid.uuid4(), + ] + + @hybrid_property + def hybrid_prop_optional_uuid(self) -> Optional[uuid.UUID]: + return None + + +class KeyedModel(Base): + __tablename__ = "test330" + id = Column(Integer(), primary_key=True) + reporter_number = Column("% reporter_number", Numeric, key="reporter_number") + + +############################################ +# For interfaces +############################################ + + +class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session + } + + +class NonAbstractPerson(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "non_abstract_person" + __mapper_args__ = { + "polymorphic_on": type, + "polymorphic_identity": "person", + } + + +class Employee(Person): + hire_date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "employee", + } + + +############################################ +# Custom Test Models +############################################ + + +class CustomIntegerColumn(HasExpressionLookup, TypeEngine): + """ + Custom Column Type that our converters don't recognize + Adapted from sqlalchemy.Integer + """ + + """A type for ``int`` integers.""" + + __visit_name__ = "integer" + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + @property + def python_type(self): + return int + + def literal_processor(self, dialect): + def process(value): + return str(int(value)) + + return process + + +class CustomColumnModel(Base): + __tablename__ = "customcolumnmodel" + + id = Column(Integer(), primary_key=True) + custom_col = Column(CustomIntegerColumn) + + +class CompositePrimaryKeyTestModel(Base): + __tablename__ = "compositekeytestmodel" + + first_name = Column(String(30), primary_key=True) + last_name = Column(String(30), primary_key=True) diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py new file mode 100644 index 00000000..e0f5d4bd --- /dev/null +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import + +import enum + +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table, func +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import column_property, relationship + +from graphene_sqlalchemy.tests.utils import wrap_select_func + +PetKind = Enum("cat", "dog", name="pet_kind") + + +class HairKind(enum.Enum): + LONG = "long" + SHORT = "short" + + +Base = declarative_base() + +association_table = Table( + "association", + Base.metadata, + Column("pet_id", Integer, ForeignKey("pets.id")), + Column("reporter_id", Integer, ForeignKey("reporters.id")), +) + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class Reporter(Base): + __tablename__ = "reporters" + + id = Column(Integer(), primary_key=True) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") + favorite_pet_kind = Column(PetKind) + pets = relationship( + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + ) + articles = relationship("Article", backref="reporter") + favorite_article = relationship("Article", uselist=False) + + column_prop = column_property( + wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property" + ) + + +class Article(Base): + __tablename__ = "articles" + id = Column(Integer(), primary_key=True) + headline = Column(String(100)) + pub_date = Column(Date()) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py new file mode 100644 index 00000000..5eccd5fc --- /dev/null +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -0,0 +1,971 @@ +import ast +import contextlib +import logging + +import pytest +from sqlalchemy import select + +import graphene +from graphene import Connection, relay + +from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory +from ..types import ORMField, SQLAlchemyObjectType +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) +from .models_batching import Article, HairKind, Pet, Reader, Reporter +from .utils import eventually_await_session, remove_cache_miss_stat, to_std_dicts + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession + + +class MockLoggingHandler(logging.Handler): + """Intercept and store log messages in a list.""" + + def __init__(self, *args, **kwargs): + self.messages = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.messages.append(record.getMessage()) + + +@contextlib.contextmanager +def mock_sqlalchemy_logging_handler(): + logging.basicConfig() + sql_logger = logging.getLogger("sqlalchemy.engine") + previous_level = sql_logger.level + + sql_logger.setLevel(logging.INFO) + mock_logging_handler = MockLoggingHandler() + mock_logging_handler.setLevel(logging.INFO) + sql_logger.addHandler(mock_logging_handler) + + yield mock_logging_handler + + sql_logger.setLevel(previous_level) + + +def get_async_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + batching = True + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + batching = True + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + async def resolve_articles(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() + + async def resolve_reporters(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + +def get_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + batching = True + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + batching = True + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, info): + session = get_session(info.context) + return session.query(Article).all() + + def resolve_reporters(self, info): + session = get_session(info.context) + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) + + +def get_full_relay_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = BatchSQLAlchemyConnectionField(ArticleType.connection) + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + readers = BatchSQLAlchemyConnectionField(ReaderType.connection) + + return graphene.Schema(query=Query) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) +async def test_many_to_one(sync_session_factory, schema_provider): + session = sync_session_factory() + schema = schema_provider() + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + article_1 = Article(headline="Article_1") + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline="Article_2") + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + result = await schema.execute_async( + """ + query { + articles { + headline + reporter { + firstName + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], + } + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN reporters" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) +async def test_one_to_one(sync_session_factory, schema_provider): + session = sync_session_factory() + schema = schema_provider() + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + article_1 = Article(headline="Article_1") + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline="Article_2") + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + + session = sync_session_factory() + result = await schema.execute_async( + """ + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], + } + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] + + +@pytest.mark.asyncio +async def test_one_to_many(sync_session_factory): + session = sync_session_factory() + + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + article_1 = Article(headline="Article_1") + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline="Article_2") + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline="Article_3") + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline="Article_4") + article_4.reporter = reporter_2 + session.add(article_4) + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + + session = sync_session_factory() + result = await schema.execute_async( + """ + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], + }, + }, + ], + } + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] + + +@pytest.mark.asyncio +async def test_many_to_many(sync_session_factory): + session = sync_session_factory() + + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + result = await schema.execute_async( + """ + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], + }, + }, + ], + } + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN pets" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] + + +def test_disable_batching_via_ormfield(sync_session_factory): + session = sync_session_factory() + reporter_1 = Reporter(first_name="Reporter_1") + session.add(reporter_1) + reporter_2 = Reporter(first_name="Reporter_2") + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + favorite_article = ORMField(batching=False) + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get("session").query(Reporter).all() + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + schema.execute( + """ + query { + reporters { + favoriteArticle { + headline + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] + assert len(select_statements) == 2 + + # Test one-to-many and many-to-many relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + schema.execute( + """ + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] + assert len(select_statements) == 2 + + +def test_batch_sorting_with_custom_ormfield(sync_session_factory): + session = sync_session_factory() + reporter_1 = Reporter(first_name="Reporter_1") + session.add(reporter_1) + reporter_2 = Reporter(first_name="Reporter_2") + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + firstname = ORMField(model_attr="first_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + result = schema.execute( + """ + query { + reporters(sort: [FIRSTNAME_DESC]) { + edges { + node { + firstname + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": { + "edges": [ + { + "node": { + "firstname": "Reporter_2", + } + }, + { + "node": { + "firstname": "Reporter_1", + } + }, + ] + } + } + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM reporters" in message + ] + assert len(select_statements) == 2 + + +@pytest.mark.asyncio +async def test_connection_factory_field_overrides_batching_is_false( + sync_session_factory, +): + session = sync_session_factory() + reporter_1 = Reporter(first_name="Reporter_1") + session.add(reporter_1) + reporter_2 = Reporter(first_name="Reporter_2") + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = False + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get("session").query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + await schema.execute_async( + """ + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + select_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] + else: + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] + assert len(select_statements) == 1 + + +def test_connection_factory_field_overrides_batching_is_true(sync_session_factory): + session = sync_session_factory() + reporter_1 = Reporter(first_name="Reporter_1") + session.add(reporter_1) + reporter_2 = Reporter(first_name="Reporter_2") + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + connection_field_factory = default_connection_field_factory + + articles = ORMField(batching=True) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get("session").query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = sync_session_factory() + schema.execute( + """ + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] + assert len(select_statements) == 2 + + +@pytest.mark.asyncio +async def test_batching_across_nested_relay_schema( + session_factory, async_session: bool +): + session = session_factory() + + for first_name in "fgerbhjikzutzxsdfdqqa": + reporter = Reporter( + first_name=first_name, + ) + session.add(reporter) + article = Article(headline="Article") + article.reporter = reporter + session.add(article) + reader = Reader(name="Reader") + reader.articles = [article] + session.add(reader) + + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + schema = get_full_relay_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = await schema.execute_async( + """ + query { + reporters { + edges { + node { + firstName + articles { + edges { + node { + id + readers { + edges { + node { + name + } + } + } + } + } + } + } + } + } + } + """, + context_value={"session": session}, + ) + messages = sqlalchemy_logging_handler.messages + + result = to_std_dicts(result.data) + select_statements = [message for message in messages if "SELECT" in message] + if async_session: + assert len(select_statements) == 2 # TODO: Figure out why async has less calls + else: + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than("1.3"): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] + + +@pytest.mark.asyncio +async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_factory): + session = session_factory() + + for first_name, email in zip("cadbbb", "aaabac"): + reporter_1 = Reporter(first_name=first_name, email=email) + session.add(reporter_1) + article_1 = Article(headline="headline") + article_1.reporter = reporter_1 + session.add(article_1) + + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + schema = get_full_relay_schema() + + session = session_factory() + result = await schema.execute_async( + """ + query { + reporters(sort: [FIRST_NAME_ASC, EMAIL_ASC]) { + edges { + node { + firstName + email + } + } + } + } + """, + context_value={"session": session}, + ) + + result = to_std_dicts(result.data) + assert [ + r["node"]["firstName"] + r["node"]["email"] + for r in result["reporters"]["edges"] + ] == ["aa", "ba", "bb", "bc", "ca", "da"] diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py new file mode 100644 index 00000000..dc656f41 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -0,0 +1,297 @@ +import asyncio + +import pytest +from sqlalchemy import select + +import graphene +from graphene import relay + +from ..types import SQLAlchemyObjectType +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) +from .models import Article, HairKind, Pet, Reporter +from .utils import eventually_await_session + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) + + +def get_async_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + async def resolve_articles(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() + + async def resolve_reporters(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + +def get_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, info): + return info.context.get("session").query(Article).all() + + def resolve_reporters(self, info): + return info.context.get("session").query(Reporter).all() + + return graphene.Schema(query=Query) + + +async def benchmark_query(session, benchmark, schema, query): + import nest_asyncio + + nest_asyncio.apply() + loop = asyncio.get_event_loop() + result = benchmark( + lambda: loop.run_until_complete( + schema.execute_async(query, context_value={"session": session}) + ) + ) + assert not result.errors + + +@pytest.fixture(params=[get_schema, get_async_schema]) +def schema_provider(request, async_session): + if async_session and request.param == get_schema: + pytest.skip("Cannot test sync schema with async sessions") + return request.param + + +@pytest.mark.asyncio +async def test_one_to_one(session_factory, benchmark, schema_provider): + session = session_factory() + schema = schema_provider() + + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + article_1 = Article(headline="Article_1") + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline="Article_2") + article_2.reporter = reporter_2 + session.add(article_2) + + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + await benchmark_query( + session, + benchmark, + schema, + """ + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """, + ) + + +@pytest.mark.asyncio +async def test_many_to_one(session_factory, benchmark, schema_provider): + session = session_factory() + schema = schema_provider() + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + article_1 = Article(headline="Article_1") + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline="Article_2") + article_2.reporter = reporter_2 + session.add(article_2) + await eventually_await_session(session, "flush") + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + await benchmark_query( + session, + benchmark, + schema, + """ + query { + articles { + headline + reporter { + firstName + } + } + } + """, + ) + + +@pytest.mark.asyncio +async def test_one_to_many(session_factory, benchmark, schema_provider): + session = session_factory() + schema = schema_provider() + + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + article_1 = Article(headline="Article_1") + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline="Article_2") + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline="Article_3") + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline="Article_4") + article_4.reporter = reporter_2 + session.add(article_4) + + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + await benchmark_query( + session, + benchmark, + schema, + """ + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + """, + ) + + +@pytest.mark.asyncio +async def test_many_to_many(session_factory, benchmark, schema_provider): + session = session_factory() + schema = schema_provider() + reporter_1 = Reporter( + first_name="Reporter_1", + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name="Reporter_2", + ) + session.add(reporter_2) + + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") + + await benchmark_query( + session, + benchmark, + schema, + """ + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } + } + } + """, + ) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f38999d2..e62e07d2 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,107 +1,376 @@ import enum +import sys +from typing import Dict, Tuple, TypeVar, Union import pytest -from sqlalchemy import Column, Table, case, func, select, types +import sqlalchemy +import sqlalchemy_utils as sqa_utils +from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from sqlalchemy.sql.elements import Label -from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene from graphene.relay import Node -from graphene.types.datetime import DateTime -from graphene.types.json import JSONString - -from ..converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_relationship) -from ..fields import (UnsortedSQLAlchemyConnectionField, - default_connection_field_factory) -from ..registry import Registry -from ..types import SQLAlchemyObjectType -from .models import Article, Pet, Reporter - - -def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): - column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs) - graphene_type = convert_sqlalchemy_column(column) - assert isinstance(graphene_type, graphene_field) - field = ( - graphene_type - if isinstance(graphene_type, graphene.Field) - else graphene_type.Field() +from graphene.types.structures import Structure + +from ..converter import ( + convert_sqlalchemy_association_proxy, + convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship, + convert_sqlalchemy_type, + set_non_null_many_relationships, +) +from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory +from ..registry import Registry, get_global_registry +from ..types import ORMField, SQLAlchemyObjectType +from ..utils import is_sqlalchemy_version_less_than +from .models import ( + Article, + CompositeFullName, + CustomColumnModel, + Pet, + ProxiedReporter, + Reporter, + ShoppingCart, + ShoppingCartItem, +) +from .utils import wrap_select_func + + +def mock_resolver(): + pass + + +def get_field(sqlalchemy_type, **column_kwargs): + class Model(declarative_base()): + __tablename__ = "model" + id_ = Column(types.Integer, primary_key=True) + column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) + + column_prop = inspect(Model).column_attrs["column"] + return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) + + +def get_field_from_column(column_): + class Model(declarative_base()): + __tablename__ = "model" + id_ = Column(types.Integer, primary_key=True) + column = column_ + + column_prop = inspect(Model).column_attrs["column"] + return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) + + +def get_hybrid_property_type(prop_method): + class Model(declarative_base()): + __tablename__ = "model" + id_ = Column(types.Integer, primary_key=True) + prop = prop_method + + column_prop = inspect(Model).all_orm_descriptors["prop"] + return convert_sqlalchemy_hybrid_method( + column_prop, mock_resolver(), **ORMField().kwargs ) - assert field.description == "Custom Help Text" - return field -def assert_composite_conversion( - composite_class, composite_columns, graphene_field, registry, **kwargs -): - composite_column = composite( - composite_class, *composite_columns, doc="Custom Help Text", **kwargs - ) - graphene_type = convert_sqlalchemy_composite(composite_column, registry) - assert isinstance(graphene_type, graphene_field) - field = graphene_type.Field() - # SQLAlchemy currently does not persist the doc onto the column, even though - # the documentation says it does.... - # assert field.description == 'Custom Help Text' - return field +@pytest.fixture +def use_legacy_many_relationships(): + set_non_null_many_relationships(False) + try: + yield + finally: + set_non_null_many_relationships(True) -def test_should_unknown_sqlalchemy_field_raise_exception(): - re_err = "Don't know how to convert the SQLAlchemy field" - with pytest.raises(Exception, match=re_err): - convert_sqlalchemy_column(None) +def test_hybrid_prop_int(): + @hybrid_property + def prop_method() -> int: + return 42 + + assert get_hybrid_property_type(prop_method).type == graphene.Int + + +def test_hybrid_unknown_annotation(): + @hybrid_property + def hybrid_prop(self): + return "This should fail" + + with pytest.raises( + TypeError, + match=r"(.*)Please make sure to annotate the return type of the hybrid property or use the " + "type_ attribute of ORMField to set the type.(.*)", + ): + get_hybrid_property_type(hybrid_prop) + + +def test_hybrid_prop_no_type_annotation(): + @hybrid_property + def hybrid_prop(self) -> Tuple[str, str]: + return "This should Fail because", "we don't support Tuples in GQL" + + with pytest.raises( + TypeError, match=r"(.*)Don't know how to convert the SQLAlchemy field(.*)" + ): + get_hybrid_property_type(hybrid_prop) + + +def test_hybrid_invalid_forward_reference(): + class MyTypeNotInRegistry: + pass + + @hybrid_property + def hybrid_prop(self) -> "MyTypeNotInRegistry": + return MyTypeNotInRegistry() + + with pytest.raises( + TypeError, + match=r"(.*)Only forward references to other SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed.(.*)", + ): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_object_type(): + class MyObjectType(graphene.ObjectType): + string = graphene.String() + + @hybrid_property + def hybrid_prop(self) -> MyObjectType: + return MyObjectType() + + assert get_hybrid_property_type(hybrid_prop).type == MyObjectType + + +def test_hybrid_prop_scalar_type(): + @hybrid_property + def hybrid_prop(self) -> graphene.String: + return "This should work" + + assert get_hybrid_property_type(hybrid_prop).type == graphene.String + + +def test_hybrid_prop_not_mapped_to_graphene_type(): + @hybrid_property + def hybrid_prop(self) -> ShoppingCartItem: + return "This shouldn't work" + + with pytest.raises(TypeError, match=r"(.*)No model found in Registry for type(.*)"): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_mapped_to_graphene_type(): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + + @hybrid_property + def hybrid_prop(self) -> ShoppingCartItem: + return "Dummy return value" + + get_hybrid_property_type(hybrid_prop).type == ShoppingCartType + + +def test_hybrid_prop_forward_ref_not_mapped_to_graphene_type(): + @hybrid_property + def hybrid_prop(self) -> "ShoppingCartItem": + return "This shouldn't work" + + with pytest.raises( + TypeError, + match=r"(.*)No model found in Registry for forward reference for type(.*)", + ): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_forward_ref_mapped_to_graphene_type(): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + + @hybrid_property + def hybrid_prop(self) -> "ShoppingCartItem": + return "Dummy return value" + + get_hybrid_property_type(hybrid_prop).type == ShoppingCartType + + +def test_converter_replace_type_var(): + + T = TypeVar("T") + + replace_type_vars = {T: graphene.String} + + field_type = convert_sqlalchemy_type(T, replace_type_vars=replace_type_vars) + + assert field_type == graphene.String + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) +def test_hybrid_prop_scalar_union_310(): + @hybrid_property + def prop_method() -> int | str: + return "not allowed in gql schema" + + with pytest.raises( + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", + ): + get_hybrid_property_type(prop_method) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) +def test_hybrid_prop_scalar_union_and_optional_310(): + """Checks if the use of Optionals does not interfere with non-conform scalar return types""" + + @hybrid_property + def prop_method() -> int | None: + return 42 + + assert get_hybrid_property_type(prop_method).type == graphene.Int + + +def test_should_union_work(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + registry = reg + + @hybrid_property + def prop_method() -> Union[PetType, ShoppingCartType]: + return None + + @hybrid_property + def prop_method_2() -> Union[ShoppingCartType, PetType]: + return None + + field_type_1 = get_hybrid_property_type(prop_method).type + field_type_2 = get_hybrid_property_type(prop_method_2).type + + assert issubclass(field_type_1, graphene.Union) + assert field_type_1._meta.types == [PetType, ShoppingCartType] + assert field_type_1 is field_type_2 + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) +def test_should_union_work_310(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + registry = reg + @hybrid_property + def prop_method() -> PetType | ShoppingCartType: + return None -def test_should_date_convert_string(): - assert_column_conversion(types.Date(), graphene.String) + @hybrid_property + def prop_method_2() -> ShoppingCartType | PetType: + return None + field_type_1 = get_hybrid_property_type(prop_method).type + field_type_2 = get_hybrid_property_type(prop_method_2).type -def test_should_datetime_convert_string(): - assert_column_conversion(types.DateTime(), DateTime) + assert issubclass(field_type_1, graphene.Union) + assert field_type_1._meta.types == [PetType, ShoppingCartType] + assert field_type_1 is field_type_2 -def test_should_time_convert_string(): - assert_column_conversion(types.Time(), graphene.String) +def test_should_unknown_type_raise_error(): + with pytest.raises(Exception): + converted_type = convert_sqlalchemy_type(ZeroDivisionError) # noqa + + +def test_should_datetime_convert_datetime(): + assert get_field(types.DateTime()).type == graphene.DateTime + + +def test_should_time_convert_time(): + assert get_field(types.Time()).type == graphene.Time + + +def test_should_date_convert_date(): + assert get_field(types.Date()).type == graphene.Date def test_should_string_convert_string(): - assert_column_conversion(types.String(), graphene.String) + assert get_field(types.String()).type == graphene.String def test_should_text_convert_string(): - assert_column_conversion(types.Text(), graphene.String) + assert get_field(types.Text()).type == graphene.String def test_should_unicode_convert_string(): - assert_column_conversion(types.Unicode(), graphene.String) + assert get_field(types.Unicode()).type == graphene.String def test_should_unicodetext_convert_string(): - assert_column_conversion(types.UnicodeText(), graphene.String) + assert get_field(types.UnicodeText()).type == graphene.String + + +def test_should_tsvector_convert_string(): + assert get_field(sqa_utils.TSVectorType()).type == graphene.String + + +def test_should_email_convert_string(): + assert get_field(sqa_utils.EmailType()).type == graphene.String + + +def test_should_URL_convert_string(): + assert get_field(sqa_utils.URLType()).type == graphene.String + + +def test_should_IPaddress_convert_string(): + assert get_field(sqa_utils.IPAddressType()).type == graphene.String + + +def test_should_inet_convert_string(): + assert get_field(postgresql.INET()).type == graphene.String + + +def test_should_cidr_convert_string(): + assert get_field(postgresql.CIDR()).type == graphene.String def test_should_enum_convert_enum(): - field = assert_column_conversion( - types.Enum(enum.Enum("TwoNumbers", ("one", "two"))), graphene.Field - ) + field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two")))) field_type = field.type() assert isinstance(field_type, graphene.Enum) + assert field_type._meta.name == "TwoNumbers" assert hasattr(field_type, "ONE") assert not hasattr(field_type, "one") assert hasattr(field_type, "TWO") assert not hasattr(field_type, "two") - field = assert_column_conversion( - types.Enum("one", "two", name="two_numbers"), graphene.Field - ) + field = get_field(types.Enum("one", "two", name="two_numbers")) field_type = field.type() - assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) + assert field_type._meta.name == "TwoNumbers" assert hasattr(field_type, "ONE") assert not hasattr(field_type, "one") assert hasattr(field_type, "TWO") @@ -109,95 +378,149 @@ def test_should_enum_convert_enum(): def test_should_not_enum_convert_enum_without_name(): - field = assert_column_conversion( - types.Enum("one", "two"), graphene.Field - ) + field = get_field(types.Enum("one", "two")) re_err = r"No type name specified for Enum\('one', 'two'\)" with pytest.raises(TypeError, match=re_err): field.type() def test_should_small_integer_convert_int(): - assert_column_conversion(types.SmallInteger(), graphene.Int) + assert get_field(types.SmallInteger()).type == graphene.Int def test_should_big_integer_convert_int(): - assert_column_conversion(types.BigInteger(), graphene.Float) + assert get_field(types.BigInteger()).type == graphene.Float def test_should_integer_convert_int(): - assert_column_conversion(types.Integer(), graphene.Int) + assert get_field(types.Integer()).type == graphene.Int -def test_should_integer_convert_id(): - assert_column_conversion(types.Integer(), graphene.ID, primary_key=True) +def test_should_primary_integer_convert_id(): + assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull( + graphene.ID + ) def test_should_boolean_convert_boolean(): - assert_column_conversion(types.Boolean(), graphene.Boolean) + assert get_field(types.Boolean()).type == graphene.Boolean def test_should_float_convert_float(): - assert_column_conversion(types.Float(), graphene.Float) + assert get_field(types.Float()).type == graphene.Float def test_should_numeric_convert_float(): - assert_column_conversion(types.Numeric(), graphene.Float) + assert get_field(types.Numeric()).type == graphene.Float + +def test_should_choice_convert_enum(): + field = get_field(sqa_utils.ChoiceType([("es", "Spanish"), ("en", "English")])) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" -def test_should_label_convert_string(): - label = Label("label_test", case([], else_="foo"), type_=types.Unicode()) - graphene_type = convert_sqlalchemy_column(label) - assert isinstance(graphene_type, graphene.String) +def test_should_enum_choice_convert_enum(): + class TestEnum(enum.Enum): + es = "Spanish" + en = "English" -def test_should_label_convert_int(): - label = Label("int_label_test", case([], else_="foo"), type_=types.Integer()) - graphene_type = convert_sqlalchemy_column(label) - assert isinstance(graphene_type, graphene.Int) + field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" -def test_should_choice_convert_enum(): - TYPES = [(u"es", u"Spanish"), (u"en", u"English")] - column = Column(ChoiceType(TYPES), doc="Language", name="language") - Base = declarative_base() +def test_choice_enum_column_key_name_issue_301(): + """ + Verifies that the sort enum name is generated from the column key instead of the name, + in case the column has an invalid enum name. See #330 + """ + + class TestEnum(enum.Enum): + es = "Spanish" + en = "English" + + testChoice = Column( + "% descuento1", + sqa_utils.ChoiceType(TestEnum, impl=types.String()), + key="descuento1", + ) + field = get_field_from_column(testChoice) - Table("translatedmodel", Base.metadata, column) - graphene_type = convert_sqlalchemy_column(column) + graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) - assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE" - assert graphene_type._meta.description == "Language" + assert graphene_type._meta.name == "MODEL_DESCUENTO1" assert graphene_type._meta.enum.__members__["es"].value == "Spanish" assert graphene_type._meta.enum.__members__["en"].value == "English" -def test_should_columproperty_convert(): +def test_should_intenum_choice_convert_enum(): + class TestEnum(enum.IntEnum): + one = 1 + two = 2 - Base = declarative_base() + field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["one"].value == 1 + assert graphene_type._meta.enum.__members__["two"].value == 2 - class Test(Base): - __tablename__ = "test" - id = Column(types.Integer, primary_key=True) - column = column_property( - select([func.sum(func.cast(id, types.Integer))]).where(id == 1) + +def test_should_columproperty_convert(): + field = get_field_from_column( + column_property( + wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1) ) + ) - graphene_type = convert_sqlalchemy_column(Test.column) - assert not graphene_type.kwargs["required"] + assert field.type == graphene.Int def test_should_scalar_list_convert_list(): - assert_column_conversion(ScalarListType(), graphene.List) + field = get_field(sqa_utils.ScalarListType()) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.String def test_should_jsontype_convert_jsonstring(): - assert_column_conversion(JSONType(), JSONString) + assert get_field(sqa_utils.JSONType()).type == graphene.JSONString + assert get_field(types.JSON).type == graphene.JSONString + + +@pytest.mark.skipif( + (not is_sqlalchemy_version_less_than("2.0.0b1")), + reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy", +) +def test_should_variant_int_convert_int(): + assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int + + +@pytest.mark.skipif( + (not is_sqlalchemy_version_less_than("2.0.0b1")), + reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy", +) +def test_should_variant_string_convert_string(): + assert get_field(types.Variant(types.String(), {})).type == graphene.String def test_should_manytomany_convert_connectionorlist(): - registry = Registry() + class A(SQLAlchemyObjectType): + class Meta: + model = Article + dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, registry, default_connection_field_factory + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -209,8 +532,36 @@ class Meta: model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) + # field should be [A!]! + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert isinstance(graphene_type.type, graphene.NonNull) + assert isinstance(graphene_type.type.of_type, graphene.List) + assert isinstance(graphene_type.type.of_type.of_type, graphene.NonNull) + assert graphene_type.type.of_type.of_type.of_type == A + + +@pytest.mark.usefixtures("use_legacy_many_relationships") +def test_should_manytomany_convert_connectionorlist_list_legacy(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", + ) + # field should be [A] assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -225,16 +576,27 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) def test_should_manytoone_convert_connectionorlist(): - registry = Registry() + class A(SQLAlchemyObjectType): + class Meta: + model = Article + dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, registry, default_connection_field_factory + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -246,7 +608,11 @@ class Meta: model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A._meta.registry, default_connection_field_factory + Article.reporter.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -261,7 +627,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A._meta.registry, default_connection_field_factory + Article.reporter.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -277,8 +647,10 @@ class Meta: dynamic_field = convert_sqlalchemy_relationship( Reporter.favorite_article.property, - A._meta.registry, + A, default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -286,17 +658,77 @@ class Meta: assert graphene_type.type == A +def test_should_convert_association_proxy(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + + field = convert_sqlalchemy_association_proxy( + Reporter, + Reporter.headlines, + ReporterType, + get_global_registry(), + default_connection_field_factory, + True, + mock_resolver, + ) + assert isinstance(field, graphene.Dynamic) + assert isinstance(field.get_type().type, graphene.List) + assert field.get_type().type.of_type == graphene.String + + dynamic_field = convert_sqlalchemy_association_proxy( + Article, + Article.recommended_reads, + ArticleType, + get_global_registry(), + default_connection_field_factory, + True, + mock_resolver, + ) + dynamic_field_type = dynamic_field.get_type().type + assert isinstance(dynamic_field, graphene.Dynamic) + assert isinstance(dynamic_field_type, graphene.NonNull) + assert isinstance(dynamic_field_type.of_type, graphene.List) + assert isinstance(dynamic_field_type.of_type.of_type, graphene.NonNull) + assert dynamic_field_type.of_type.of_type.of_type == ArticleType + + +def test_should_throw_error_association_proxy_unsupported_target(): + class ProxiedReporterType(SQLAlchemyObjectType): + class Meta: + model = ProxiedReporter + + field = convert_sqlalchemy_association_proxy( + ProxiedReporter, + ProxiedReporter.composite_prop, + ProxiedReporterType, + get_global_registry(), + default_connection_field_factory, + True, + mock_resolver, + ) + + with pytest.raises(TypeError): + field.get_type() + + def test_should_postgresql_uuid_convert(): - assert_column_conversion(postgresql.UUID(), graphene.String) + assert get_field(postgresql.UUID()).type == graphene.UUID + + +def test_should_sqlalchemy_utils_uuid_convert(): + assert get_field(sqa_utils.UUIDType()).type == graphene.UUID def test_should_postgresql_enum_convert(): - field = assert_column_conversion( - postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field - ) + field = get_field(postgresql.ENUM("one", "two", name="two_numbers")) field_type = field.type() - assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) + assert field_type._meta.name == "TwoNumbers" assert hasattr(field_type, "ONE") assert not hasattr(field_type, "one") assert hasattr(field_type, "TWO") @@ -304,9 +736,8 @@ def test_should_postgresql_enum_convert(): def test_should_postgresql_py_enum_convert(): - field = assert_column_conversion( - postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), - graphene.Field, + field = get_field( + postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers") ) field_type = field.type() assert field_type._meta.name == "TwoNumbers" @@ -318,55 +749,222 @@ def test_should_postgresql_py_enum_convert(): def test_should_postgresql_array_convert(): - assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List) + field = get_field(postgresql.ARRAY(types.Integer)) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.Int + + +def test_should_array_convert(): + field = get_field(types.ARRAY(types.Integer)) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.Int + + +def test_should_2d_array_convert(): + field = get_field(types.ARRAY(types.Integer, dimensions=2)) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.List) + assert field.type.of_type.of_type == graphene.Int + + +def test_should_3d_array_convert(): + field = get_field(types.ARRAY(types.Integer, dimensions=3)) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type, graphene.List) + assert field.type.of_type.of_type.of_type == graphene.Int def test_should_postgresql_json_convert(): - assert_column_conversion(postgresql.JSON(), JSONString) + assert get_field(postgresql.JSON()).type == graphene.JSONString def test_should_postgresql_jsonb_convert(): - assert_column_conversion(postgresql.JSONB(), JSONString) + assert get_field(postgresql.JSONB()).type == graphene.JSONString def test_should_postgresql_hstore_convert(): - assert_column_conversion(postgresql.HSTORE(), JSONString) + assert get_field(postgresql.HSTORE()).type == graphene.JSONString def test_should_composite_convert(): + registry = Registry() + class CompositeClass: def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 - registry = Registry() - @convert_sqlalchemy_composite.register(CompositeClass, registry) def convert_composite_class(composite, registry): return graphene.String(description=composite.doc) - assert_composite_conversion( - CompositeClass, - (Column(types.Unicode(50)), Column(types.Unicode(50))), - graphene.String, + field = convert_sqlalchemy_composite( + composite( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + doc="Custom Help Text", + ), registry, + mock_resolver, ) + assert isinstance(field, graphene.String) def test_should_unknown_sqlalchemy_composite_raise_exception(): - registry = Registry() + class CompositeClass: + def __init__(self, col1, col2): + self.col1 = col1 + self.col2 = col2 re_err = "Don't know how to convert the composite field" with pytest.raises(Exception, match=re_err): + convert_sqlalchemy_composite( + composite( + CompositeFullName, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + ), + Registry(), + mock_resolver, + ) - class CompositeClass(object): - def __init__(self, col1, col2): - self.col1 = col1 - self.col2 = col2 - assert_composite_conversion( - CompositeClass, - (Column(types.Unicode(50)), Column(types.Unicode(50))), - graphene.String, - registry, +def test_raise_exception_unkown_column_type(): + with pytest.raises( + Exception, + match="Don't know how to convert the SQLAlchemy field customcolumnmodel.custom_col", + ): + + class A(SQLAlchemyObjectType): + class Meta: + model = CustomColumnModel + + +def test_prioritize_orm_field_unkown_column_type(): + class A(SQLAlchemyObjectType): + class Meta: + model = CustomColumnModel + + custom_col = ORMField(type_=graphene.Int) + + assert A._meta.fields["custom_col"].type == graphene.Int + + +def test_match_supertype_from_mro_correct_order(): + """ + BigInt and Integer are both superclasses of BIGINT, but a custom converter exists for BigInt that maps to Float. + We expect the correct MRO order to be used and conversion by the nearest match. BIGINT should be converted to Float, + just like BigInt, not to Int like integer which is further up in the MRO. + """ + + class BIGINT(sqlalchemy.types.BigInteger): + pass + + field = get_field_from_column(Column(BIGINT)) + + assert field.type == graphene.Float + + +def test_sqlalchemy_hybrid_property_type_inference(): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + interfaces = (Node,) + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + interfaces = (Node,) + + ####################################################### + # Check ShoppingCartItem's Properties and Return Types + ####################################################### + + shopping_cart_item_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { + "hybrid_prop_shopping_cart": graphene.List(ShoppingCartType) + } + + assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted( + [ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_item_expected_types.keys(), + ] + ) + + for ( + hybrid_prop_name, + hybrid_prop_expected_return_type, + ) in shopping_cart_item_expected_types.items(): + hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. + assert (hybrid_prop_name, str(hybrid_prop_field.type)) == ( + hybrid_prop_name, + str(hybrid_prop_expected_return_type), + ) + assert ( + hybrid_prop_field.description is None + ) # "doc" is ignored by hybrid property + + ################################################### + # Check ShoppingCart's Properties and Return Types + ################################################### + + shopping_cart_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { + # Basic types + "hybrid_prop_str": graphene.String, + "hybrid_prop_int": graphene.Int, + "hybrid_prop_float": graphene.Float, + "hybrid_prop_bool": graphene.Boolean, + "hybrid_prop_decimal": graphene.String, # Decimals should be serialized Strings + "hybrid_prop_date": graphene.Date, + "hybrid_prop_time": graphene.Time, + "hybrid_prop_datetime": graphene.DateTime, + # Lists and Nested Lists + "hybrid_prop_list_int": graphene.List(graphene.Int), + "hybrid_prop_list_date": graphene.List(graphene.Date), + "hybrid_prop_nested_list_int": graphene.List(graphene.List(graphene.Int)), + "hybrid_prop_deeply_nested_list_int": graphene.List( + graphene.List(graphene.List(graphene.Int)) + ), + "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_first_shopping_cart_item_expression": ShoppingCartItemType, + "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), + # Self Referential List + "hybrid_prop_self_referential": ShoppingCartType, + "hybrid_prop_self_referential_list": graphene.List(ShoppingCartType), + # Optionals + "hybrid_prop_optional_self_referential": ShoppingCartType, + # UUIDs + "hybrid_prop_uuid": graphene.UUID, + "hybrid_prop_optional_uuid": graphene.UUID, + "hybrid_prop_uuid_list": graphene.List(graphene.UUID), + } + + assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted( + [ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_expected_types.keys(), + ] + ) + + for ( + hybrid_prop_name, + hybrid_prop_expected_return_type, + ) in shopping_cart_expected_types.items(): + hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. + assert (hybrid_prop_name, str(hybrid_prop_field.type)) == ( + hybrid_prop_name, + str(hybrid_prop_expected_return_type), ) + assert ( + hybrid_prop_field.description is None + ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index ca376964..3de6904b 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -54,7 +54,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): assert [ (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() - ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): @@ -65,7 +65,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): assert [ (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() - ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): @@ -80,36 +80,38 @@ class PetType(SQLAlchemyObjectType): class Meta: model = Pet - enum = enum_for_field(PetType, 'pet_kind') + enum = enum_for_field(PetType, "pet_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "PetKind" assert [ - (key, value.value) - for key, value in enum._meta.enum.__members__.items() - ] == [("CAT", 'cat'), ("DOG", 'dog')] - enum2 = enum_for_field(PetType, 'pet_kind') + (key, value.value) for key, value in enum._meta.enum.__members__.items() + ] == [ + ("CAT", "cat"), + ("DOG", "dog"), + ] + enum2 = enum_for_field(PetType, "pet_kind") assert enum2 is enum - enum2 = PetType.enum_for_field('pet_kind') + enum2 = PetType.enum_for_field("pet_kind") assert enum2 is enum - enum = enum_for_field(PetType, 'hair_kind') + enum = enum_for_field(PetType, "hair_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "HairKind" assert enum._meta.enum is HairKind - enum2 = PetType.enum_for_field('hair_kind') + enum2 = PetType.enum_for_field("hair_kind") assert enum2 is enum re_err = r"Cannot get PetType\.other_kind" with pytest.raises(TypeError, match=re_err): - enum_for_field(PetType, 'other_kind') + enum_for_field(PetType, "other_kind") with pytest.raises(TypeError, match=re_err): - PetType.enum_for_field('other_kind') + PetType.enum_for_field("other_kind") re_err = r"PetType\.name does not map to enum column" with pytest.raises(TypeError, match=re_err): - enum_for_field(PetType, 'name') + enum_for_field(PetType, "name") with pytest.raises(TypeError, match=re_err): - PetType.enum_for_field('name') + PetType.enum_for_field("name") re_err = r"Expected a field name, but got: None" with pytest.raises(TypeError, match=re_err): @@ -119,4 +121,4 @@ class Meta: re_err = "Expected SQLAlchemyObjectType, but got: None" with pytest.raises(TypeError, match=re_err): - enum_for_field(None, 'other_kind') + enum_for_field(None, "other_kind") diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 0f8738f0..9fed146d 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,8 +1,10 @@ import pytest +from promise import Promise -from graphene.relay import Connection +from graphene import NonNull, ObjectType +from graphene.relay import Connection, Node -from ..fields import SQLAlchemyConnectionField +from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from .models import Editor as EditorModel from .models import Pet as PetModel @@ -11,6 +13,7 @@ class Pet(SQLAlchemyObjectType): class Meta: model = PetModel + interfaces = (Node,) class Editor(SQLAlchemyObjectType): @@ -18,27 +21,73 @@ class Meta: model = EditorModel -class PetConn(Connection): - class Meta: - node = Pet +## +# SQLAlchemyConnectionField +## + + +def test_nonnull_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(NonNull(Pet.connection)) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + +def test_required_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(Pet.connection, required=True) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + +def test_promise_connection_resolver(): + def resolver(_obj, _info): + return Promise.resolve([]) + + result = UnsortedSQLAlchemyConnectionField.connection_resolver( + resolver, Pet.connection, Pet, None, None + ) + assert isinstance(result, Promise) + + +def test_type_assert_sqlalchemy_object_type(): + with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"): + SQLAlchemyConnectionField(ObjectType).type + + +def test_type_assert_object_has_connection(): + with pytest.raises(AssertionError, match="doesn't have a connection"): + SQLAlchemyConnectionField(Editor).type + + +## +# UnsortedSQLAlchemyConnectionField +## + + +def test_unsorted_connection_field_removes_sort_arg_if_passed(): + editor = UnsortedSQLAlchemyConnectionField( + Editor.connection, sort=Editor.sort_argument(has_default=True) + ) + assert "sort" not in editor.args def test_sort_added_by_default(): - field = SQLAlchemyConnectionField(PetConn) + field = SQLAlchemyConnectionField(Pet.connection) assert "sort" in field.args assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - field = SQLAlchemyConnectionField(PetConn, sort=None) + field = SQLAlchemyConnectionField(Pet.connection, sort=None) assert "sort" not in field.args def test_custom_sort(): - field = SQLAlchemyConnectionField(PetConn, sort=Editor.sort_argument()) + field = SQLAlchemyConnectionField(Pet.connection, sort=Editor.sort_argument()) assert field.args["sort"] == Editor.sort_argument() -def test_init_raises(): +def test_sort_init_raises(): with pytest.raises(TypeError, match="Cannot create sort"): SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py new file mode 100644 index 00000000..87bbceae --- /dev/null +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -0,0 +1,1228 @@ +import pytest +from sqlalchemy.sql.operators import is_ + +import graphene +from graphene import Connection, relay + +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType +from .models import ( + Article, + Editor, + HairKind, + Image, + Pet, + Reader, + Reporter, + ShoppingCart, + ShoppingCartItem, + Tag, +) +from .utils import eventually_await_session, to_std_dicts + +# TODO test that generated schema is correct for all examples with: +# with open('schema.gql', 'w') as fp: +# fp.write(str(schema)) + + +def assert_and_raise_result(result, expected): + if result.errors: + for error in result.errors: + raise error + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +async def add_test_data(session): + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") + session.add(reporter) + + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT, legs=4) + pet.reporter = reporter + session.add(pet) + + pet = Pet(name="Snoopy", pet_kind="dog", hair_kind=HairKind.SHORT, legs=3) + pet.reporter = reporter + session.add(pet) + + reporter = Reporter(first_name="John", last_name="Woe", favorite_pet_kind="cat") + session.add(reporter) + + article = Article(headline="Hi!") + article.reporter = reporter + session.add(article) + + article = Article(headline="Hello!") + article.reporter = reporter + session.add(article) + + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") + session.add(reporter) + + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) + pet.reporter = reporter + session.add(pet) + + editor = Editor(name="Jack") + session.add(editor) + + await eventually_await_session(session, "commit") + + +def create_schema(session): + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + + class ImageType(SQLAlchemyObjectType): + class Meta: + model = Image + name = "Image" + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + class TagType(SQLAlchemyObjectType): + class Meta: + model = Tag + name = "Tag" + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection) + images = SQLAlchemyConnectionField(ImageType.connection) + readers = SQLAlchemyConnectionField(ReaderType.connection) + reporters = SQLAlchemyConnectionField(ReporterType.connection) + pets = SQLAlchemyConnectionField(PetType.connection) + tags = SQLAlchemyConnectionField(TagType.connection) + + return Query + + +# Test a simple example of filtering +@pytest.mark.asyncio +async def test_filter_simple(session): + await add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filter: {lastName: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_filter_alias(session): + """ + Test aliasing of column names in the type + """ + await add_test_data(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + lastNameAlias = ORMField(model_attr="last_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = SQLAlchemyConnectionField(ReporterType.connection) + + query = """ + query { + reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a custom filter type +@pytest.mark.asyncio +async def test_filter_custom_type(session): + await add_test_data(session) + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection + + legs = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + query = """ + query { + pets (filter: { + legs: {divisibleBy: 2} + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": { + "edges": [{"node": {"name": "Garfield"}}, {"node": {"name": "Lassie"}}] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test filtering on enums +@pytest.mark.asyncio +async def test_filter_enum(session): + await add_test_data(session) + + Query = create_schema(session) + + # test sqlalchemy enum + query = """ + query { + reporters (filter: { + favoritePetKind: {eq: DOG} + } + ) { + edges { + node { + firstName + lastName + favoritePetKind + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + } + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test Python enum and sqlalchemy enum + query = """ + query { + pets (filter: { + and: [ + { hairKind: {eq: LONG} }, + { petKind: {eq: DOG} } + ]}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Lassie"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a 1:1 relationship +@pytest.mark.asyncio +async def test_filter_relationship_one_to_one(session): + article = Article(headline="Hi!") + image = Image(external_id=1, description="A beautiful image.") + article.image = image + session.add(article) + session.add(image) + await eventually_await_session(session, "commit") + + Query = create_schema(session) + + query = """ + query { + articles (filter: { + image: {description: {eq: "A beautiful image."}} + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a 1:n relationship +@pytest.mark.asyncio +async def test_filter_relationship_one_to_many(session): + await add_test_data(session) + Query = create_schema(session) + + # test contains + query = """ + query { + reporters (filter: { + articles: { + contains: [{headline: {eq: "Hi!"}}], + } + }) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # TODO test containsExactly + # # test containsExactly + # query = """ + # query { + # reporters (filter: { + # articles: { + # containsExactly: [ + # {headline: {eq: "Hi!"}} + # {headline: {eq: "Hello!"}} + # ] + # } + # }) { + # edges { + # node { + # firstName + # lastName + # } + # } + # } + # } + # """ + # expected = { + # "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} + # } + # schema = graphene.Schema(query=Query) + # result = await schema.execute_async(query, context_value={"session": session}) + # assert_and_raise_result(result, expected) + + +async def add_n2m_test_data(session): + # create objects + reader1 = Reader(name="Ada") + reader2 = Reader(name="Bip") + article1 = Article(headline="Article! Look!") + article2 = Article(headline="Woah! Another!") + tag1 = Tag(name="sensational") + tag2 = Tag(name="eye-grabbing") + image1 = Image(description="article 1") + image2 = Image(description="article 2") + + # set relationships + article1.tags = [tag1] + article2.tags = [tag1, tag2] + article1.image = image1 + article2.image = image2 + reader1.articles = [article1] + reader2.articles = [article1, article2] + + # save + session.add(image1) + session.add(image2) + session.add(tag1) + session.add(tag2) + session.add(article1) + session.add(article2) + session.add(reader1) + session.add(reader2) + await eventually_await_session(session, "commit") + + +# Test n:m relationship contains +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains(session): + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Article! Look!"}}, + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_with_and(session): + """ + This test is necessary to ensure we don't accidentally turn and-contains filter + into or-contains filters due to incorrect aliasing of the joined table. + """ + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [{ + and: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + { name: { eq: "eye-grabbing" } }, + ] + + } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m relationship containsExactly +@pytest.mark.xfail +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_exactly(session): + raise NotImplementedError + await add_n2m_test_data(session) + Query = create_schema(session) + + # test containsExactly 1 + query = """ + query { + articles (filter: { + tags: { + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test containsExactly 2 + query = """ + query { + articles (filter: { + tags: { + containsExactly: [ + { name: { eq: "sensational" } } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Article! Look!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + containsExactly: [ + { headline: { eq: "Article! Look!" } }, + { headline: { eq: "Woah! Another!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "eye-grabbing"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m relationship both contains and containsExactly +@pytest.mark.xfail +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): + raise NotImplementedError + await add_n2m_test_data(session) + Query = create_schema(session) + + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m nested relationship +# TODO add containsExactly +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_nested(session): + await add_n2m_test_data(session) + Query = create_schema(session) + + # test readers->articles relationship + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test nested readers->articles->tags + query = """ + query { + readers (filter: { + articles: { + contains: [ + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test nested reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { + readers: { + contains: [ + { name: { eq: "Ada" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "sensational"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test filter on both levels of nesting + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" } }, + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "and" +@pytest.mark.asyncio +async def test_filter_logic_and(session): + await add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + and: [ + { firstName: { eq: "John" } }, + { favoritePetKind: { eq: CAT } }, + ] + }) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": { + "edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "or" +@pytest.mark.asyncio +async def test_filter_logic_or(session): + await add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + or: [ + { lastName: { eq: "Woe" } }, + { favoritePetKind: { eq: DOG } }, + ] + }) { + edges { + node { + firstName + lastName + favoritePetKind + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + { + "node": { + "firstName": "John", + "lastName": "Woe", + "favoritePetKind": "CAT", + } + }, + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + }, + ] + } + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "and" and "or" together +@pytest.mark.asyncio +async def test_filter_logic_and_or(session): + await add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + and: [ + { firstName: { eq: "John" } }, + { + or: [ + { lastName: { eq: "Doe" } }, + # TODO get enums working for filters + # { favoritePetKind: { eq: "cat" } }, + ] + } + ] + }) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + {"node": {"firstName": "John"}}, + # {"node": {"firstName": "Jane"}}, + ], + } + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +async def add_hybrid_prop_test_data(session): + cart = ShoppingCart() + session.add(cart) + await eventually_await_session(session, "commit") + + +def create_hybrid_prop_schema(session): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + name = "ShoppingCartItem" + interfaces = (relay.Node,) + connection_class = Connection + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + name = "ShoppingCart" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + items = SQLAlchemyConnectionField(ShoppingCartItemType.connection) + carts = SQLAlchemyConnectionField(ShoppingCartType.connection) + + return Query + + +# Test filtering over and returning hybrid_property +@pytest.mark.asyncio +async def test_filter_hybrid_property(session): + await add_hybrid_prop_test_data(session) + Query = create_hybrid_prop_schema(session) + + # test hybrid_prop_int + query = """ + query { + carts (filter: {hybridPropInt: {eq: 42}}) { + edges { + node { + hybridPropInt + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropInt": 42}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test hybrid_prop_float + query = """ + query { + carts (filter: {hybridPropFloat: {gt: 42}}) { + edges { + node { + hybridPropFloat + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropFloat": 42.3}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test hybrid_prop different model without expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItem { + id + } + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + + # test hybrid_prop different model with expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItemExpression { + id + } + } + } + } + } + """ + + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + + # test hybrid_prop list of models + query = """ + query { + carts { + edges { + node { + hybridPropShoppingCartItemList { + id + } + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + assert ( + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + ) + + +# Test edge cases to improve test coverage +@pytest.mark.asyncio +async def test_filter_edge_cases(session): + await add_test_data(session) + + # test disabling filtering + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection, filter=None) + + schema = graphene.Schema(query=Query) + assert not hasattr(schema, "ArticleTypeFilter") + + +# Test additional filter types to improve test coverage +@pytest.mark.asyncio +async def test_additional_filters(session): + await add_test_data(session) + Query = create_schema(session) + + # test n_eq and not_in filters + query = """ + query { + reporters (filter: {firstName: {nEq: "Jane"}, lastName: {notIn: "Doe"}}) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test gt, lt, gte, and lte filters + query = """ + query { + pets (filter: {legs: {gt: 2, lt: 4, gte: 3, lte: 3}}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Snoopy"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_do_not_create_filters(): + class WithoutFilters(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, _meta=None, **options): + super().__init_subclass_with_meta__( + _meta=_meta, create_filters=False, **options + ) + + class PetType(WithoutFilters): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + schema = graphene.Schema(query=Query) + + assert "filter" not in str(schema).lower() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 5279bd87..168a82f9 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,44 +1,57 @@ -import graphene -from graphene.relay import Connection, Node - -from ..fields import SQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType -from .models import Article, Editor, HairKind, Pet, Reporter +from datetime import date +import pytest +from sqlalchemy import select -def to_std_dicts(value): - """Convert nested ordered dicts to normal dicts for better comparison.""" - if isinstance(value, dict): - return {k: to_std_dicts(v) for k, v in value.items()} - elif isinstance(value, list): - return [to_std_dicts(v) for v in value] - else: - return value - +import graphene +from graphene.relay import Node -def add_test_data(session): - reporter = Reporter( - first_name='John', last_name='Doe', favorite_pet_kind='cat') +from ..converter import convert_sqlalchemy_composite +from ..fields import SQLAlchemyConnectionField +from ..types import ORMField, SQLAlchemyInterface, SQLAlchemyObjectType +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session +from .models import ( + Article, + CompositeFullName, + Editor, + Employee, + HairKind, + Person, + Pet, + Reporter, +) +from .utils import eventually_await_session, to_std_dicts + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession + + +async def add_test_data(session): + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) - pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) session.add(pet) pet.reporters.append(reporter) - article = Article(headline='Hi!') + article = Article(headline="Hi!") article.reporter = reporter session.add(article) - reporter = Reporter( - first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") session.add(reporter) - pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) pet.reporters.append(reporter) session.add(pet) editor = Editor(name="Jack") session.add(editor) - session.commit() + await eventually_await_session(session, "commit") -def test_should_query_well(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_fields(session): + await add_test_data(session) + + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.String() class ReporterType(SQLAlchemyObjectType): class Meta: @@ -48,18 +61,26 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) query = """ - query ReporterQuery { + query { reporter { firstName - lastName - email + columnProp + hybridProp + compositeProp + headlines } reporters { firstName @@ -67,18 +88,25 @@ def resolve_reporters(self, _info): } """ expected = { - "reporter": {"firstName": "John", "lastName": "Doe", "email": None}, + "reporter": { + "firstName": "John", + "hybridProp": "John", + "columnProp": 2, + "compositeProp": "John Doe", + "headlines": ["Hi!"], + }, "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_should_query_node(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_node_sync(session): + await add_test_data(session) class ReporterNode(SQLAlchemyObjectType): class Meta: @@ -94,27 +122,112 @@ class Meta: model = Article interfaces = (Node,) - class ArticleConnection(Connection): + class Query(graphene.ObjectType): + node = Node.Field() + reporter = graphene.Field(ReporterNode) + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) + + def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + + return session.query(Reporter).first() + + query = """ + query { + reporter { + id + firstName + articles { + edges { + node { + headline + } + } + } + } + allArticles { + edges { + node { + headline + } + } + } + myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { + id + ... on ReporterNode { + firstName + } + ... on ArticleNode { + headline + } + } + } + """ + expected = { + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "John", + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, + "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, + } + schema = graphene.Schema(query=Query) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + result = schema.execute(query, context_value={"session": session}) + assert result.errors + else: + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +@pytest.mark.asyncio +async def test_query_node_async(session): + await add_test_data(session) + + class ReporterNode(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_node(cls, info, id): + return Reporter(id=2, first_name="Cookie Monster") + + class ArticleNode(SQLAlchemyObjectType): class Meta: - node = ArticleNode + model = Article + interfaces = (Node,) class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) - article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleConnection) + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) def resolve_reporter(self, _info): - return session.query(Reporter).first() + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): - def resolve_article(self, _info): - return session.query(Article).first() + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + + return session.query(Reporter).first() query = """ - query ReporterQuery { + query { reporter { id - firstName, + firstName articles { edges { node { @@ -122,8 +235,6 @@ def resolve_article(self, _info): } } } - lastName, - email } allArticles { edges { @@ -147,38 +258,104 @@ def resolve_article(self, _info): "reporter": { "id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John", - "lastName": "Doe", - "email": None, "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, }, "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_should_custom_identifier(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_orm_field(session): + await add_test_data(session) - class EditorNode(SQLAlchemyObjectType): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.String() + + class ReporterType(SQLAlchemyObjectType): class Meta: - model = Editor + model = Reporter interfaces = (Node,) - class EditorConnection(Connection): + first_name_v2 = ORMField(model_attr="first_name") + hybrid_prop_v2 = ORMField(model_attr="hybrid_prop") + column_prop_v2 = ORMField(model_attr="column_prop") + composite_prop = ORMField() + favorite_article_v2 = ORMField(model_attr="favorite_article") + articles_v2 = ORMField(model_attr="articles") + + class ArticleType(SQLAlchemyObjectType): class Meta: - node = EditorNode + model = Article + interfaces = (Node,) + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).first() + return session.query(Reporter).first() + + query = """ + query { + reporter { + firstNameV2 + hybridPropV2 + columnPropV2 + compositeProp + favoriteArticleV2 { + headline + } + articlesV2(first: 1) { + edges { + node { + headline + } + } + } + } + } + """ + expected = { + "reporter": { + "firstNameV2": "John", + "hybridPropV2": "John", + "columnPropV2": 2, + "compositeProp": "John Doe", + "favoriteArticleV2": {"headline": "Hi!"}, + "articlesV2": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +@pytest.mark.asyncio +async def test_custom_identifier(session): + await add_test_data(session) + + class EditorNode(SQLAlchemyObjectType): + class Meta: + model = Editor + interfaces = (Node,) class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorConnection) + all_editors = SQLAlchemyConnectionField(EditorNode.connection) query = """ - query EditorQuery { + query { allEditors { edges { node { @@ -200,14 +377,15 @@ class Query(graphene.ObjectType): } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_should_mutate_well(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_mutation(session, session_factory): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -220,8 +398,11 @@ class Meta: interfaces = (Node,) @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name="Cookie Monster") + async def get_node(cls, id, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() + return session.query(Reporter).first() class ArticleNode(SQLAlchemyObjectType): class Meta: @@ -236,11 +417,14 @@ class Arguments: ok = graphene.Boolean() article = graphene.Field(ArticleNode) - def mutate(self, info, headline, reporter_id): + async def mutate(self, info, headline, reporter_id): + reporter = await ReporterNode.get_node(reporter_id, info) new_article = Article(headline=headline, reporter_id=reporter_id) + reporter.articles = [*reporter.articles, new_article] + session = get_session(info.context) + session.add(reporter) - session.add(new_article) - session.commit() + await eventually_await_session(session, "commit") ok = True return CreateArticle(article=new_article, ok=ok) @@ -252,7 +436,7 @@ class Mutation(graphene.ObjectType): create_article = CreateArticle.Field() query = """ - mutation ArticleCreator { + mutation { createArticle( headline: "My Article" reporterId: "1" @@ -279,7 +463,65 @@ class Mutation(graphene.ObjectType): } schema = graphene.Schema(query=Query, mutation=Mutation) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors result = to_std_dicts(result.data) assert result == expected + + +async def add_person_data(session): + bob = Employee(name="Bob", birth_date=date(1990, 1, 1), hire_date=date(2015, 1, 1)) + session.add(bob) + joe = Employee(name="Joe", birth_date=date(1980, 1, 1), hire_date=date(2010, 1, 1)) + session.add(joe) + jen = Employee(name="Jen", birth_date=date(1995, 1, 1), hire_date=date(2020, 1, 1)) + session.add(jen) + await eventually_await_session(session, "commit") + + +@pytest.mark.asyncio +async def test_interface_query_on_base_type(session_factory): + session = session_factory() + await add_person_data(session) + + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + class Query(graphene.ObjectType): + people = graphene.Field(graphene.List(PersonType)) + + async def resolve_people(self, _info): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Person))).all() + return session.query(Person).all() + + schema = graphene.Schema(query=Query, types=[PersonType, EmployeeType]) + result = await schema.execute_async( + """ + query { + people { + __typename + name + birthDate + ... on EmployeeType { + hireDate + } + } + } + """ + ) + + assert not result.errors + assert len(result.data["people"]) == 3 + assert result.data["people"][0]["__typename"] == "EmployeeType" + assert result.data["people"][0]["name"] == "Bob" + assert result.data["people"][0]["birthDate"] == "1990-01-01" + assert result.data["people"][0]["hireDate"] == "2015-01-01" diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index ec585d57..14c87f74 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,15 +1,24 @@ +import pytest +from sqlalchemy import select + import graphene +from graphene_sqlalchemy.tests.utils import eventually_await_session +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from ..types import SQLAlchemyObjectType from .models import HairKind, Pet, Reporter from .test_query import add_test_data, to_std_dicts +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession -def test_query_pet_kinds(session): - add_test_data(session) - class PetType(SQLAlchemyObjectType): +@pytest.mark.asyncio +async def test_query_pet_kinds(session, session_factory): + await add_test_data(session) + await eventually_await_session(session, "close") + class PetType(SQLAlchemyObjectType): class Meta: model = Pet @@ -20,19 +29,32 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - pets = graphene.List(PetType, kind=graphene.Argument( - PetType.enum_for_field('pet_kind'))) + pets = graphene.List( + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) - def resolve_pets(self, _info, kind): + async def resolve_pets(self, _info, kind): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).unique().all() query = session.query(Pet) if kind: - query = query.filter_by(pet_kind=kind) + query = query.filter_by(pet_kind=kind.value) return query query = """ @@ -58,36 +80,36 @@ def resolve_pets(self, _info, kind): } """ expected = { - 'reporter': { - 'firstName': 'John', - 'lastName': 'Doe', - 'email': None, - 'favoritePetKind': 'CAT', - 'pets': [{ - 'name': 'Garfield', - 'petKind': 'CAT' - }] + "reporter": { + "firstName": "John", + "lastName": "Doe", + "email": None, + "favoritePetKind": "CAT", + "pets": [{"name": "Garfield", "petKind": "CAT"}], }, - 'reporters': [{ - 'firstName': 'John', - 'favoritePetKind': 'CAT', - }, { - 'firstName': 'Jane', - 'favoritePetKind': 'DOG', - }], - 'pets': [{ - 'name': 'Lassie', - 'petKind': 'DOG' - }] + "reporters": [ + { + "firstName": "John", + "favoritePetKind": "CAT", + }, + { + "firstName": "Jane", + "favoritePetKind": "DOG", + }, + ], + "pets": [{"name": "Lassie", "petKind": "DOG"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors assert result.data == expected -def test_query_more_enums(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_more_enums(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -96,7 +118,10 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field(PetType) - def resolve_pet(self, _info): + async def resolve_pet(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Pet))).first() return session.query(Pet).first() query = """ @@ -110,14 +135,15 @@ def resolve_pet(self, _info): """ expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -125,13 +151,19 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field( - PetType, - kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) - def resolve_pet(self, info, kind=None): + async def resolve_pet(self, info, kind=None): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).first() query = session.query(Pet) if kind: - query = query.filter(Pet.pet_kind == kind) + query = query.filter(Pet.pet_kind == kind.value) return query.first() query = """ @@ -145,19 +177,24 @@ def resolve_pet(self, info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "CAT"}) + result = await schema.execute_async( + query, variables={"kind": "CAT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "DOG"}) + result = await schema.execute_async( + query, variables={"kind": "DOG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) assert result == expected -def test_py_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_py_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -169,7 +206,14 @@ class Query(graphene.ObjectType): kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), ) - def resolve_pet(self, _info, kind=None): + async def resolve_pet(self, _info, kind=None): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return ( + await session.scalars( + select(Pet).filter(Pet.hair_kind == HairKind(kind)) + ) + ).first() query = session.query(Pet) if kind: # enum arguments are expected to be strings, not PyEnums @@ -187,11 +231,15 @@ def resolve_pet(self, _info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "SHORT"}) + result = await schema.execute_async( + query, variables={"kind": "SHORT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "LONG"}) + result = await schema.execute_async( + query, variables={"kind": "LONG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) diff --git a/graphene_sqlalchemy/tests/test_reflected.py b/graphene_sqlalchemy/tests/test_reflected.py index 46e10de9..a3f6c4aa 100644 --- a/graphene_sqlalchemy/tests/test_reflected.py +++ b/graphene_sqlalchemy/tests/test_reflected.py @@ -1,4 +1,3 @@ - from graphene import ObjectType from ..registry import Registry diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 0403c4f0..e54f08b1 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,12 +1,13 @@ import pytest from sqlalchemy.types import Enum as SQLAlchemyEnum +import graphene from graphene import Enum as GrapheneEnum from ..registry import Registry from ..types import SQLAlchemyObjectType from ..utils import EnumValue -from .models import Pet +from .models import Pet, Reporter def test_register_object_type(): @@ -27,7 +28,7 @@ def test_register_incorrect_object_type(): class Spam: pass - re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + re_err = "Expected SQLAlchemyBase, but got: .*Spam" with pytest.raises(TypeError, match=re_err): reg.register(Spam) @@ -50,7 +51,7 @@ def test_register_orm_field_incorrect_types(): class Spam: pass - re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + re_err = "Expected SQLAlchemyBase, but got: .*Spam" with pytest.raises(TypeError, match=re_err): reg.register_orm_field(Spam, "name", Pet.name) @@ -126,3 +127,56 @@ class Meta: re_err = r"Expected Graphene Enum, but got: .*PetType.*" with pytest.raises(TypeError, match=re_err): reg.register_sort_enum(PetType, PetType) + + +def test_register_union(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + union_types = [PetType, ReporterType] + union = graphene.Union.create_type("ReporterPet", types=tuple(union_types)) + + reg.register_union_type(union, union_types) + + assert reg.get_union_for_object_types(union_types) == union + # Order should not matter + assert reg.get_union_for_object_types([ReporterType, PetType]) == union + + +def test_register_union_scalar(): + reg = Registry() + + union_types = [graphene.String, graphene.Int] + union = graphene.Union.create_type("StringInt", types=union_types) + + re_err = r"Expected Graphene ObjectType, but got: .*String.*" + with pytest.raises(TypeError, match=re_err): + reg.register_union_type(union, union_types) + + +def test_register_union_incorrect_types(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + union_types = [PetType, ReporterType] + union = PetType + + re_err = r"Expected graphene.Union, but got: .*PetType.*" + with pytest.raises(TypeError, match=re_err): + reg.register_union_type(union, union_types) diff --git a/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/tests/test_schema.py deleted file mode 100644 index 87739bdb..00000000 --- a/graphene_sqlalchemy/tests/test_schema.py +++ /dev/null @@ -1,50 +0,0 @@ -from py.test import raises - -from ..registry import Registry -from ..types import SQLAlchemyObjectType -from .models import Reporter - - -def test_should_raise_if_no_model(): - with raises(Exception) as excinfo: - - class Character1(SQLAlchemyObjectType): - pass - - assert "valid SQLAlchemy Model" in str(excinfo.value) - - -def test_should_raise_if_model_is_invalid(): - with raises(Exception) as excinfo: - - class Character2(SQLAlchemyObjectType): - class Meta: - model = 1 - - assert "valid SQLAlchemy Model" in str(excinfo.value) - - -def test_should_map_fields_correctly(): - class ReporterType2(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = Registry() - - assert list(ReporterType2._meta.fields.keys()) == [ - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - "pets", - "articles", - "favorite_article", - ] - - -def test_should_map_only_few_fields(): - class Reporter2(SQLAlchemyObjectType): - class Meta: - model = Reporter - only_fields = ("id", "email") - assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 1eb106da..bb530f2c 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -2,23 +2,24 @@ import sqlalchemy as sa from graphene import Argument, Enum, List, ObjectType, Schema -from graphene.relay import Connection, Node +from graphene.relay import Node from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from ..utils import to_type_name -from .models import Base, HairKind, Pet +from .models import Base, HairKind, KeyedModel, Pet from .test_query import to_std_dicts +from .utils import eventually_await_session -def add_pets(session): +async def add_pets(session): pets = [ Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), ] session.add_all(pets) - session.commit() + await eventually_await_session(session, "commit") def test_sort_enum(): @@ -40,6 +41,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -94,6 +97,8 @@ class Meta: "PET_KIND_DESC", "HAIR_KIND_ASC", "HAIR_KIND_DESC", + "LEGS_ASC", + "LEGS_DESC", ] @@ -134,6 +139,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -148,7 +155,7 @@ def test_sort_argument_with_excluded_fields_in_object_type(): class PetType(SQLAlchemyObjectType): class Meta: model = Pet - exclude_fields = ["hair_kind", "reporter_id"] + exclude_fields = ["hair_kind", "reporter_id", "legs"] sort_arg = PetType.sort_argument() sort_enum = sort_arg.type._of_type @@ -237,34 +244,33 @@ def get_symbol_name(column_name, sort_asc=True): "HairKindDown", "ReporterIdUp", "ReporterIdDown", + "LegsUp", + "LegsDown", ] assert sort_arg.default_value == ["IdUp"] -def test_sort_query(session): - add_pets(session) +@pytest.mark.asyncio +async def test_sort_query(session): + await add_pets(session) class PetNode(SQLAlchemyObjectType): class Meta: model = Pet interfaces = (Node,) - class PetConnection(Connection): - class Meta: - node = PetNode - class Query(ObjectType): - defaultSort = SQLAlchemyConnectionField(PetConnection) - nameSort = SQLAlchemyConnectionField(PetConnection) - multipleSort = SQLAlchemyConnectionField(PetConnection) - descSort = SQLAlchemyConnectionField(PetConnection) + defaultSort = SQLAlchemyConnectionField(PetNode.connection) + nameSort = SQLAlchemyConnectionField(PetNode.connection) + multipleSort = SQLAlchemyConnectionField(PetNode.connection) + descSort = SQLAlchemyConnectionField(PetNode.connection) singleColumnSort = SQLAlchemyConnectionField( - PetConnection, sort=Argument(PetNode.sort_enum()) + PetNode.connection, sort=Argument(PetNode.sort_enum()) ) noDefaultSort = SQLAlchemyConnectionField( - PetConnection, sort=PetNode.sort_argument(has_default=False) + PetNode.connection, sort=PetNode.sort_argument(has_default=False) ) - noSort = SQLAlchemyConnectionField(PetConnection, sort=None) + noSort = SQLAlchemyConnectionField(PetNode.connection, sort=None) query = """ query sortTest { @@ -340,7 +346,7 @@ def makeNodes(nodeList): } # yapf: disable schema = Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -356,9 +362,9 @@ def makeNodes(nodeList): } } """ - result = schema.execute(queryError, context_value={"session": session}) + result = await schema.execute_async(queryError, context_value={"session": session}) assert result.errors is not None - assert '"sort" has invalid value' in result.errors[0].message + assert "cannot represent non-enum value" in result.errors[0].message queryNoSort = """ query sortTest { @@ -379,7 +385,7 @@ def makeNodes(nodeList): } """ - result = schema.execute(queryNoSort, context_value={"session": session}) + result = await schema.execute_async(queryNoSort, context_value={"session": session}) assert not result.errors # TODO: SQLite usually returns the results ordered by primary key, # so we cannot test this way whether sorting actually happens or not. @@ -387,3 +393,32 @@ def makeNodes(nodeList): assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] ] + + +def test_sort_enum_from_key_issue_330(): + """ + Verifies that the sort enum name is generated from the column key instead of the name, + in case the column has an invalid enum name. See #330 + """ + + class KeyedType(SQLAlchemyObjectType): + class Meta: + model = KeyedModel + + sort_enum = KeyedType.sort_enum() + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "KeyedTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "REPORTER_NUMBER_ASC", + "REPORTER_NUMBER_DESC", + ] + assert ( + str(sort_enum.REPORTER_NUMBER_ASC.value.value) + == 'test330."% reporter_number" ASC' + ) + assert ( + str(sort_enum.REPORTER_NUMBER_DESC.value.value) + == 'test330."% reporter_number" DESC' + ) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index b76136fb..f25b0dc2 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,282 +1,860 @@ -from collections import OrderedDict +import re +from unittest import mock + +import pytest +import sqlalchemy.exc +import sqlalchemy.orm.exc +from graphql.pyutils import is_awaitable +from sqlalchemy import select + +from graphene import ( + Boolean, + DefaultGlobalIDType, + Dynamic, + Field, + Float, + GlobalID, + Int, + List, + Node, + NonNull, + ObjectType, + Schema, + String, +) +from graphene.relay import Connection + +from .. import utils +from ..converter import convert_sqlalchemy_composite +from ..fields import ( + SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, + createConnectionField, + registerConnectionFieldFactory, + unregisterConnectionFieldFactory, +) +from ..types import ( + ORMField, + SQLAlchemyInterface, + SQLAlchemyObjectType, + SQLAlchemyObjectTypeOptions, +) +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 +from .models import ( + Article, + CompositeFullName, + CompositePrimaryKeyTestModel, + Employee, + NonAbstractPerson, + Person, + Pet, + Reporter, +) +from .utils import eventually_await_session + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession + + +def test_should_raise_if_no_model(): + re_err = r"valid SQLAlchemy Model" + with pytest.raises(Exception, match=re_err): + + class Character1(SQLAlchemyObjectType): + pass + + +def test_should_raise_if_model_is_invalid(): + re_err = r"valid SQLAlchemy Model" + with pytest.raises(Exception, match=re_err): + + class Character(SQLAlchemyObjectType): + class Meta: + model = 1 + + +@pytest.mark.asyncio +async def test_sqlalchemy_node(session): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) -import six # noqa F401 -from promise import Promise + reporter_id_field = ReporterType._meta.fields["id"] + assert isinstance(reporter_id_field, GlobalID) -from graphene import (Connection, Field, Int, Interface, Node, ObjectType, - is_node) + reporter = Reporter() + session.add(reporter) + await eventually_await_session(session, "commit") + info = mock.Mock(context={"session": session}) + reporter_node = ReporterType.get_node(info, reporter.id) + if is_awaitable(reporter_node): + reporter_node = await reporter_node + assert reporter == reporter_node -from ..fields import (SQLAlchemyConnectionField, - UnsortedSQLAlchemyConnectionField, - registerConnectionFieldFactory, - unregisterConnectionFieldFactory) -from ..registry import Registry -from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions -from .models import Article, Reporter -registry = Registry() +def test_connection(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + assert issubclass(ReporterType.connection, Connection) -class Character(SQLAlchemyObjectType): - """Character description""" - class Meta: - model = Reporter - registry = registry +def test_sqlalchemy_default_fields(): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return String() + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) -class Human(SQLAlchemyObjectType): - """Human description""" + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) - pub_date = Int() + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Columns + "column_prop", # SQLAlchemy retuns column properties first + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + # Composite + "composite_prop", + # Hybrid + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + # Relationship + "pets", + "articles", + "favorite_article", + # AssociationProxy + "headlines", + ] + ) - class Meta: - model = Article - exclude_fields = ("id",) - registry = registry - interfaces = (Node,) + # column + first_name_field = ReporterType._meta.fields["first_name"] + assert first_name_field.type == String + assert first_name_field.description == "First name" + + # column_property + column_prop_field = ReporterType._meta.fields["column_prop"] + assert column_prop_field.type == Int + # "doc" is ignored by column_property + assert column_prop_field.description is None + + # composite + full_name_field = ReporterType._meta.fields["composite_prop"] + assert full_name_field.type == String + # "doc" is ignored by composite + assert full_name_field.description is None + + # hybrid_property + hybrid_prop = ReporterType._meta.fields["hybrid_prop"] + assert hybrid_prop.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop.description is None + + # hybrid_property_str + hybrid_prop_str = ReporterType._meta.fields["hybrid_prop_str"] + assert hybrid_prop_str.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop_str.description is None + + # hybrid_property_int + hybrid_prop_int = ReporterType._meta.fields["hybrid_prop_int"] + assert hybrid_prop_int.type == Int + # "doc" is ignored by hybrid_property + assert hybrid_prop_int.description is None + + # hybrid_property_float + hybrid_prop_float = ReporterType._meta.fields["hybrid_prop_float"] + assert hybrid_prop_float.type == Float + # "doc" is ignored by hybrid_property + assert hybrid_prop_float.description is None + + # hybrid_property_bool + hybrid_prop_bool = ReporterType._meta.fields["hybrid_prop_bool"] + assert hybrid_prop_bool.type == Boolean + # "doc" is ignored by hybrid_property + assert hybrid_prop_bool.description is None + + # hybrid_property_list + hybrid_prop_list = ReporterType._meta.fields["hybrid_prop_list"] + assert hybrid_prop_list.type == List(Int) + # "doc" is ignored by hybrid_property + assert hybrid_prop_list.description is None + + # hybrid_prop_with_doc + hybrid_prop_with_doc = ReporterType._meta.fields["hybrid_prop_with_doc"] + assert hybrid_prop_with_doc.type == String + # docstring is picked up from hybrid_prop_with_doc + assert hybrid_prop_with_doc.description == "Docstring test" + + # relationship + favorite_article_field = ReporterType._meta.fields["favorite_article"] + assert isinstance(favorite_article_field, Dynamic) + assert favorite_article_field.type().type == ArticleType + assert favorite_article_field.type().description is None + + # assocation proxy + assoc_field = ReporterType._meta.fields["headlines"] + assert isinstance(assoc_field, Dynamic) + assert isinstance(assoc_field.type().type, List) + assert assoc_field.type().type.of_type == String + + assoc_field = ArticleType._meta.fields["recommended_reads"] + assert isinstance(assoc_field, Dynamic) + assert assoc_field.type().type == ArticleType.connection + + +def test_sqlalchemy_override_fields(): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return String() + + class ReporterMixin(object): + # columns + first_name = ORMField(required=True) + last_name = ORMField(description="Overridden") + + class ReporterType(SQLAlchemyObjectType, ReporterMixin): + class Meta: + model = Reporter + interfaces = (Node,) + # columns + email = ORMField(deprecation_reason="Overridden") + email_v2 = ORMField(model_attr="email", type_=Int) -def test_sqlalchemy_interface(): - assert issubclass(Node, Interface) - assert issubclass(Node, Node) + # column_property + column_prop = ORMField(type_=String) + # composite + composite_prop = ORMField() -# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) -# def test_sqlalchemy_get_node(get): -# human = Human.get_node(1, None) -# get.assert_called_with(id=1) -# assert human.id == 1 + # hybrid_property + hybrid_prop_with_doc = ORMField(description="Overridden") + hybrid_prop = ORMField(description="Overridden") + # relationships + favorite_article = ORMField(description="Overridden") + articles = ORMField(deprecation_reason="Overridden") + pets = ORMField(description="Overridden") -def test_objecttype_registered(): - assert issubclass(Character, ObjectType) - assert Character._meta.model == Reporter - assert list(Character._meta.fields.keys()) == [ - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - "pets", - "articles", - "favorite_article", - ] + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + use_connection = False + + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Fields from ReporterMixin + "first_name", + "last_name", + # Fields from ReporterType + "email", + "email_v2", + "column_prop", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "favorite_article", + "articles", + "pets", + # Then the automatic SQLAlchemy fields + "id", + "favorite_pet_kind", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + "headlines", + ] + ) -# def test_sqlalchemynode_idfield(): -# idfield = Node._meta.fields_map['id'] -# assert isinstance(idfield, GlobalIDField) + first_name_field = ReporterType._meta.fields["first_name"] + assert isinstance(first_name_field.type, NonNull) + assert first_name_field.type.of_type == String + assert first_name_field.description == "First name" + assert first_name_field.deprecation_reason is None + + last_name_field = ReporterType._meta.fields["last_name"] + assert last_name_field.type == String + assert last_name_field.description == "Overridden" + assert last_name_field.deprecation_reason is None + + email_field = ReporterType._meta.fields["email"] + assert email_field.type == String + assert email_field.description == "Email" + assert email_field.deprecation_reason == "Overridden" + + email_field_v2 = ReporterType._meta.fields["email_v2"] + assert email_field_v2.type == Int + assert email_field_v2.description == "Email" + assert email_field_v2.deprecation_reason is None + + hybrid_prop_field = ReporterType._meta.fields["hybrid_prop"] + assert hybrid_prop_field.type == String + assert hybrid_prop_field.description == "Overridden" + assert hybrid_prop_field.deprecation_reason is None + + hybrid_prop_with_doc_field = ReporterType._meta.fields["hybrid_prop_with_doc"] + assert hybrid_prop_with_doc_field.type == String + assert hybrid_prop_with_doc_field.description == "Overridden" + assert hybrid_prop_with_doc_field.deprecation_reason is None + + column_prop_field_v2 = ReporterType._meta.fields["column_prop"] + assert column_prop_field_v2.type == String + assert column_prop_field_v2.description is None + assert column_prop_field_v2.deprecation_reason is None + + composite_prop_field = ReporterType._meta.fields["composite_prop"] + assert composite_prop_field.type == String + assert composite_prop_field.description is None + assert composite_prop_field.deprecation_reason is None + + favorite_article_field = ReporterType._meta.fields["favorite_article"] + assert isinstance(favorite_article_field, Dynamic) + assert favorite_article_field.type().type == ArticleType + assert favorite_article_field.type().description == "Overridden" + + articles_field = ReporterType._meta.fields["articles"] + assert isinstance(articles_field, Dynamic) + assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) + assert articles_field.type().deprecation_reason == "Overridden" + + pets_field = ReporterType._meta.fields["pets"] + assert isinstance(pets_field, Dynamic) + assert isinstance(pets_field.type().type, NonNull) + assert isinstance(pets_field.type().type.of_type, List) + assert isinstance(pets_field.type().type.of_type.of_type, NonNull) + assert pets_field.type().type.of_type.of_type.of_type == PetType + assert pets_field.type().description == "Overridden" + + +def test_invalid_model_attr(): + err_msg = ( + "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" + ) + with pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter -# def test_node_idfield(): -# idfield = Human._meta.fields_map['id'] -# assert isinstance(idfield, GlobalIDField) + first_name = ORMField(model_attr="does_not_exist") -def test_node_replacedfield(): - idfield = Human._meta.fields["pub_date"] - assert isinstance(idfield, Field) - assert idfield.type == Int +def test_only_fields(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "last_name") + first_name = ORMField() # Takes precedence + last_name = ORMField() # Noop -def test_object_type(): - class Human(SQLAlchemyObjectType): - """Human description""" + assert list(ReporterType._meta.fields.keys()) == ["first_name", "last_name", "id"] - pub_date = Int() +def test_exclude_fields(): + class ReporterType(SQLAlchemyObjectType): class Meta: - model = Article - # exclude_fields = ('id', ) - registry = registry - interfaces = (Node,) + model = Reporter + exclude_fields = ("id", "first_name") + + first_name = ORMField() # Takes precedence + last_name = ORMField() # Noop + + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + "first_name", + "last_name", + "column_prop", + "email", + "favorite_pet_kind", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + "pets", + "articles", + "favorite_article", + "headlines", + ] + ) - assert issubclass(Human, ObjectType) - assert list(Human._meta.fields.keys()) == [ - "id", - "headline", - "pub_date", - "reporter_id", - "reporter", - ] - assert is_node(Human) + +def test_only_and_exclude_fields(): + re_err = r"'only_fields' and 'exclude_fields' cannot be both set" + with pytest.raises(Exception, match=re_err): + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "last_name") + exclude_fields = ("id", "last_name") + + +def test_sqlalchemy_redefine_field(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + first_name = Int() + + first_name_field = ReporterType._meta.fields["first_name"] + assert isinstance(first_name_field, Field) + assert first_name_field.type == Int + + +@pytest.mark.asyncio +async def test_resolvers(session): + """Test that the correct resolver functions are called""" + + reporter = Reporter( + first_name="first_name", + last_name="last_name", + email="email", + favorite_pet_kind="cat", + ) + session.add(reporter) + await eventually_await_session(session, "commit") + + class ReporterMixin(object): + def resolve_id(root, _info): + return "ID" + + class ReporterType(ReporterMixin, SQLAlchemyObjectType): + class Meta: + model = Reporter + + email = ORMField() + email_v2 = ORMField(model_attr="email") + favorite_pet_kind = Field(String) + favorite_pet_kind_v2 = Field(String) + + def resolve_last_name(root, _info): + return root.last_name.upper() + + def resolve_email_v2(root, _info): + return root.email + "_V2" + + def resolve_favorite_pet_kind_v2(root, _info): + return str(root.favorite_pet_kind) + "_V2" + + class Query(ObjectType): + reporter = Field(ReporterType) + + async def resolve_reporter(self, _info): + session = utils.get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() + return session.query(Reporter).first() + + schema = Schema(query=Query) + result = await schema.execute_async( + """ + query { + reporter { + id + firstName + lastName + email + emailV2 + favoritePetKind + favoritePetKindV2 + } + } + """, + context_value={"session": session}, + ) + + assert not result.errors + # Custom resolver on a base class + assert result.data["reporter"]["id"] == "ID" + # Default field + default resolver + assert result.data["reporter"]["firstName"] == "first_name" + # Default field + custom resolver + assert result.data["reporter"]["lastName"] == "LAST_NAME" + # ORMField + default resolver + assert result.data["reporter"]["email"] == "email" + # ORMField + custom resolver + assert result.data["reporter"]["emailV2"] == "email_V2" + # Field + default resolver + assert result.data["reporter"]["favoritePetKind"] == "cat" + # Field + custom resolver + assert result.data["reporter"]["favoritePetKindV2"] == "cat_V2" # Test Custom SQLAlchemyObjectType Implementation -class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): - class Meta: - abstract = True -class CustomCharacter(CustomSQLAlchemyObjectType): - """Character description""" +@pytest.mark.asyncio +async def test_composite_id_resolver(session): + """Test that the correct resolver functions are called""" + + composite_reporter = CompositePrimaryKeyTestModel( + first_name="graphql", last_name="foundation" + ) + + session.add(composite_reporter) + await eventually_await_session(session, "commit") + + class CompositePrimaryKeyTestModelType(SQLAlchemyObjectType): + class Meta: + model = CompositePrimaryKeyTestModel + interfaces = (Node,) - class Meta: - model = Reporter - registry = registry + class Query(ObjectType): + composite_reporter = Field(CompositePrimaryKeyTestModelType) + + async def resolve_composite_reporter(self, _info): + session = utils.get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return ( + (await session.scalars(select(CompositePrimaryKeyTestModel))) + .unique() + .first() + ) + return session.query(CompositePrimaryKeyTestModel).first() + + schema = Schema(query=Query) + result = await schema.execute_async( + """ + query { + compositeReporter { + id + firstName + lastName + } + } + """, + context_value={"session": session}, + ) + + assert not result.errors + assert result.data["compositeReporter"]["id"] == DefaultGlobalIDType.to_global_id( + CompositePrimaryKeyTestModelType, str(("graphql", "foundation")) + ) def test_custom_objecttype_registered(): - assert issubclass(CustomCharacter, ObjectType) - assert CustomCharacter._meta.model == Reporter - assert list(CustomCharacter._meta.fields.keys()) == [ - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - "pets", - "articles", - "favorite_article", - ] + class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): + class Meta: + abstract = True + + class CustomReporterType(CustomSQLAlchemyObjectType): + class Meta: + model = Reporter + + assert issubclass(CustomReporterType, ObjectType) + assert CustomReporterType._meta.model == Reporter + assert len(CustomReporterType._meta.fields) == 18 # Test Custom SQLAlchemyObjectType with Custom Options -class CustomOptions(SQLAlchemyObjectTypeOptions): - custom_option = None - custom_fields = None +def test_objecttype_with_custom_options(): + class CustomOptions(SQLAlchemyObjectTypeOptions): + custom_option = None + + class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, custom_option=None, **options): + _meta = CustomOptions(cls) + _meta.custom_option = custom_option + super( + SQLAlchemyObjectTypeWithCustomOptions, cls + ).__init_subclass_with_meta__(_meta=_meta, **options) + + class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): + class Meta: + model = Reporter + custom_option = "custom_option" + assert issubclass(ReporterWithCustomOptions, ObjectType) + assert ReporterWithCustomOptions._meta.model == Reporter + assert ReporterWithCustomOptions._meta.custom_option == "custom_option" -class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): - class Meta: - abstract = True - @classmethod - def __init_subclass_with_meta__( - cls, custom_option=None, custom_fields=None, **options +def test_interface_with_polymorphic_identity(): + with pytest.raises( + AssertionError, + match=re.escape( + 'PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")' + ), ): - _meta = CustomOptions(cls) - _meta.custom_option = custom_option - _meta.fields = custom_fields - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + class PersonType(SQLAlchemyInterface): + class Meta: + model = NonAbstractPerson -class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): - class Meta: - model = Reporter - custom_option = "custom_option" - custom_fields = OrderedDict([("custom_field", Field(Int()))]) +def test_interface_inherited_fields(): + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person -def test_objecttype_with_custom_options(): - assert issubclass(ReporterWithCustomOptions, ObjectType) - assert ReporterWithCustomOptions._meta.model == Reporter - assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ - "custom_field", + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + assert PersonType in EmployeeType._meta.interfaces + + name_field = EmployeeType._meta.fields["name"] + assert name_field.type == String + + # `type` should *not* be in this list because it's the polymorphic_on + # discriminator for Person + assert list(EmployeeType._meta.fields.keys()) == [ "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - "pets", - "articles", - "favorite_article", + "name", + "birth_date", + "hire_date", ] - assert ReporterWithCustomOptions._meta.custom_option == "custom_option" - assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int) -def test_promise_connection_resolver(): - class TestConnection(Connection): +def test_interface_type_field_orm_override(): + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + type = ORMField() + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + assert PersonType in EmployeeType._meta.interfaces + + name_field = EmployeeType._meta.fields["name"] + assert name_field.type == String + + # type should be in this list because we used ORMField + # to force its presence on the model + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted( + [ + "id", + "name", + "type", + "birth_date", + "hire_date", + ] + ) + + +def test_interface_custom_resolver(): + class PersonType(SQLAlchemyInterface): class Meta: - node = ReporterWithCustomOptions + model = Person - def resolver(_obj, _info): - return Promise.resolve([]) + custom_field = Field(String) - result = SQLAlchemyConnectionField.connection_resolver( - resolver, TestConnection, ReporterWithCustomOptions, None, None + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + assert PersonType in EmployeeType._meta.interfaces + + name_field = EmployeeType._meta.fields["name"] + assert name_field.type == String + + # type should be in this list because we used ORMField + # to force its presence on the model + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted( + [ + "id", + "name", + "custom_field", + "birth_date", + "hire_date", + ] ) - assert result is not None # Tests for connection_field_factory + class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): pass def test_default_connection_field_factory(): - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - registry = _registry interfaces = (Node,) class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), UnsortedSQLAlchemyConnectionField + ) -def test_register_connection_field_factory(): +def test_custom_connection_field_factory(): def test_connection_field_factory(relationship, registry): model = relationship.mapper.entity _type = registry.get_type_for_model(model) return _TestSQLAlchemyConnectionField(_type._meta.connection) - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - registry = _registry interfaces = (Node,) connection_field_factory = test_connection_field_factory class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_registerConnectionFieldFactory(): - registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + with pytest.warns(DeprecationWarning): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) - _registry = Registry() + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) - class ReporterType(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = _registry - interfaces = (Node,) - - class ArticleType(SQLAlchemyObjectType): - class Meta: - model = Article - registry = _registry - interfaces = (Node,) + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_unregisterConnectionFieldFactory(): - registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) - unregisterConnectionFieldFactory() + with pytest.warns(DeprecationWarning): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + unregisterConnectionFieldFactory() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + assert not isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) - _registry = Registry() - class ReporterType(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = _registry - interfaces = (Node,) +def test_deprecated_createConnectionField(): + with pytest.warns(DeprecationWarning): + createConnectionField(None) - class ArticleType(SQLAlchemyObjectType): - class Meta: - model = Article - registry = _registry - interfaces = (Node,) - assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) +@mock.patch(utils.__name__ + ".class_mapper") +def test_unique_errors_propagate(class_mapper_mock): + # Define unique error to detect + class UniqueError(Exception): + pass + + # Mock class_mapper effect + class_mapper_mock.side_effect = UniqueError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + + class ArticleOne(SQLAlchemyObjectType): + class Meta(object): + model = Article + + except UniqueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, UniqueError) + + +@mock.patch(utils.__name__ + ".class_mapper") +def test_argument_errors_propagate(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + + class ArticleTwo(SQLAlchemyObjectType): + class Meta(object): + model = Article + + except sqlalchemy.exc.ArgumentError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, sqlalchemy.exc.ArgumentError) + + +@mock.patch(utils.__name__ + ".class_mapper") +def test_unmapped_errors_reformat(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + + class ArticleThree(SQLAlchemyObjectType): + class Meta(object): + model = Article + + except ValueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, ValueError) + assert "You need to pass a valid SQLAlchemy Model" in str(error) diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index e13d919c..75328280 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -3,8 +3,14 @@ from graphene import Enum, List, ObjectType, Schema, String -from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, - to_enum_value_name, to_type_name) +from ..utils import ( + DummyImport, + get_session, + sort_argument_for_model, + sort_enum_for_model, + to_enum_value_name, + to_type_name, +) from .models import Base, Editor, Pet @@ -96,6 +102,12 @@ class MultiplePK(Base): with pytest.warns(DeprecationWarning): arg = sort_argument_for_model(MultiplePK) - assert set(arg.default_value) == set( - (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") - ) + assert set(arg.default_value) == { + MultiplePK.foo.name + "_asc", + MultiplePK.bar.name + "_asc", + } + + +def test_dummy_import(): + dummy_module = DummyImport() + assert dummy_module.foo == object diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py new file mode 100644 index 00000000..6e843316 --- /dev/null +++ b/graphene_sqlalchemy/tests/utils.py @@ -0,0 +1,37 @@ +import inspect +import re + +from sqlalchemy import select + +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 + + +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value + + +def remove_cache_miss_stat(message): + """Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4""" + # https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177 + return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) + + +def wrap_select_func(query): + # TODO remove this when we drop support for sqa < 2.0 + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + return select(query) + else: + return select([query]) + + +async def eventually_await_session(session, func, *args): + if inspect.iscoroutinefunction(getattr(session, func)): + await getattr(session, func)(*args) + else: + getattr(session, func)(*args) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index c20e8cfc..894ebfdb 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,102 +1,429 @@ +import inspect +import logging +import warnings from collections import OrderedDict +from functools import partial +from inspect import isawaitable +from typing import Any, Optional, Type, Union import sqlalchemy +from sqlalchemy.ext.associationproxy import AssociationProxy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.inspection import inspect as sqlalchemyinspect +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound -from graphene import Field +import graphene +from graphene import Dynamic, Field, InputField from graphene.relay import Connection, Node +from graphene.types.base import BaseType +from graphene.types.interface import Interface, InterfaceOptions from graphene.types.objecttype import ObjectType, ObjectTypeOptions +from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs - -from .converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_hybrid_method, - convert_sqlalchemy_relationship) -from .enums import (enum_for_field, sort_argument_for_object_type, - sort_enum_for_object_type) -from .fields import default_connection_field_factory +from graphene.utils.orderedtype import OrderedType + +from .converter import ( + convert_sqlalchemy_association_proxy, + convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship, +) +from .enums import ( + enum_for_field, + sort_argument_for_object_type, + sort_enum_for_object_type, +) +from .filters import BaseTypeFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry -from .utils import get_query, is_mapped_class, is_mapped_instance +from .resolvers import get_attr_resolver, get_custom_resolver +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_nullable_type, + get_query, + get_session, + is_mapped_class, + is_mapped_instance, +) + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession + +logger = logging.getLogger(__name__) + + +class ORMField(OrderedType): + def __init__( + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + create_filter=None, + filter_type: Optional[Type] = None, + _creation_counter=None, + **field_kwargs, + ): + """ + Use this to override fields automatically generated by SQLAlchemyObjectType. + Unless specified, options will default to SQLAlchemyObjectType usual behavior + for the given SQLAlchemy model property. + + Usage: + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel + + id = ORMField(type_=graphene.Int) + name = ORMField(required=True) + + -> MyType.id will be of type Int (vs ID). + -> MyType.name will be of type NonNull(String) (vs String). + + :param str model_attr: + Name of the SQLAlchemy model attribute used to resolve this field. + Default to the name of the attribute referencing the ORMField. + :param type_: + Default to the type mapping in converter.py. + :param str description: + Default to the `doc` attribute of the SQLAlchemy column property. + :param bool required: + Default to the opposite of the `nullable` attribute of the SQLAlchemy column property. + :param str description: + Same behavior as in graphene.Field. Defaults to None. + :param str deprecation_reason: + Same behavior as in graphene.Field. Defaults to None. + :param bool batching: + Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. + :param bool create_filter: + Create a filter for this field. Defaults to True. + :param Type filter_type: + Override for the filter of this field with a custom filter type. + Default behavior is to get a matching filter type for this field from the registry. + Create_filter needs to be true + :param int _creation_counter: + Same behavior as in graphene.Field. + """ + super(ORMField, self).__init__(_creation_counter=_creation_counter) + # The is only useful for documentation and auto-completion + common_kwargs = { + "model_attr": model_attr, + "type_": type_, + "required": required, + "description": description, + "deprecation_reason": deprecation_reason, + "create_filter": create_filter, + "filter_type": filter_type, + "batching": batching, + } + common_kwargs = { + kwarg: value for kwarg, value in common_kwargs.items() if value is not None + } + self.kwargs = field_kwargs + self.kwargs.update(common_kwargs) + + +def get_or_create_relationship_filter( + base_type: Type[BaseType], registry: Registry +) -> Type[RelationshipFilter]: + relationship_filter = registry.get_relationship_filter_for_base_type(base_type) + + if not relationship_filter: + try: + base_type_filter = registry.get_filter_for_base_type(base_type) + relationship_filter = RelationshipFilter.create_type( + f"{base_type.__name__}RelationshipFilter", + base_type_filter=base_type_filter, + model=base_type._meta.model, + ) + registry.register_relationship_filter_for_base_type( + base_type, relationship_filter + ) + except Exception as e: + print("e") + raise e + + return relationship_filter + + +def filter_field_from_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + type_, + registry: Registry, + model_attr: Any, + model_attr_name: str, +) -> Optional[graphene.InputField]: + # Field might be a SQLAlchemyObjectType, due to hybrid properties + if issubclass(type_, SQLAlchemyObjectType): + filter_class = registry.get_filter_for_base_type(type_) + # Enum Special Case + elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): + column = model_attr.columns[0] + model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None) + if not getattr(model_enum_type, "enum_class", None): + filter_class = registry.get_filter_for_sql_enum_type(type_) + else: + filter_class = registry.get_filter_for_py_enum_type(type_) + else: + filter_class = registry.get_filter_for_scalar_type(type_) + if not filter_class: + warnings.warn( + f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field." + ) + return None + return SQLAlchemyFilterInputField(filter_class, model_attr_name) + + +def resolve_dynamic_relationship_filter( + field: graphene.Dynamic, registry: Registry, model_attr_name: str +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # Resolve Dynamic Type + type_ = get_nullable_type(field.get_type()) + from graphene_sqlalchemy import SQLAlchemyConnectionField + + # Connections always result in list filters + if isinstance(type_, SQLAlchemyConnectionField): + inner_type = get_nullable_type(type_.type.Edge.node._type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + # Field relationships can either be a list or a single object + elif isinstance(type_, Field): + if isinstance(type_.type, graphene.List): + inner_type = get_nullable_type(type_.type.of_type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + else: + reg_res = registry.get_filter_for_base_type(type_.type) + else: + # Other dynamic type constellation are not yet supported, + # please open an issue with reproduction if you need them + reg_res = None + + if not reg_res: + warnings.warn( + f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field." + ) + return None + + return SQLAlchemyFilterInputField(reg_res, model_attr_name) + + +def filter_field_from_type_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], + model_attr: Any, + model_attr_name: str, +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # If a custom filter type was set for this field, use it here + if filter_type: + return SQLAlchemyFilterInputField(filter_type, model_attr_name) + elif issubclass(type(field), graphene.Scalar): + filter_class = registry.get_filter_for_scalar_type(type(field)) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) + # If the generated field is Dynamic, it is always a relationship + # (due to graphene-sqlalchemy's conversion mechanism). + elif isinstance(field, graphene.Dynamic): + return Dynamic( + partial( + resolve_dynamic_relationship_filter, field, registry, model_attr_name + ) + ) + # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them + elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): + # Pure lists are not yet supported + pass + elif isinstance(field._type, graphene.Dynamic): + # Fields with nested dynamic Dynamic are not yet supported + pass + # Order matters, this comes last as field._type == list also matches Field + elif isinstance(field, graphene.Field): + if inspect.isfunction(field._type) or isinstance(field._type, partial): + return Dynamic( + lambda: filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + ) + else: + return filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) -def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, connection_field_factory +def get_polymorphic_on(model): + """ + Check whether this model is a polymorphic type, and if so return the name + of the discriminator field (`polymorphic_on`), so that it won't be automatically + generated as an ORMField. + """ + if hasattr(model, "__mapper__") and model.__mapper__.polymorphic_on is not None: + polymorphic_on = model.__mapper__.polymorphic_on + if isinstance(polymorphic_on, sqlalchemy.Column): + return polymorphic_on.name + + +def construct_fields_and_filters( + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + create_filters, + connection_field_factory, ): - inspected_model = sqlalchemyinspect(model) - - fields = OrderedDict() - - for name, column in inspected_model.columns.items(): - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - converted_column = convert_sqlalchemy_column(column, registry) - registry.register_orm_field(obj_type, name, column) - fields[name] = converted_column - - for name, composite in inspected_model.composites.items(): - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields + """ + Construct all the fields for a SQLAlchemyObjectType. + The main steps are: + - Gather all the relevant attributes from the SQLAlchemy model + - Gather all the ORM fields defined on the type + - Merge in overrides and build up all the fields + + :param SQLAlchemyObjectType obj_type: + :param model: the SQLAlchemy model + :param Registry registry: + :param tuple[string] only_fields: + :param tuple[string] exclude_fields: + :param bool batching: + :param bool create_filters: Enable filter generation for this type + :param function|None connection_field_factory: + :rtype: OrderedDict[str, graphene.Field] + """ + inspected_model = sqlalchemy.inspect(model) + # Gather all the relevant attributes from the SQLAlchemy model in order + all_model_attrs = OrderedDict( + inspected_model.column_attrs.items() + + inspected_model.composites.items() + + [ + (name, item) + for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property) or isinstance(item, AssociationProxy) + ] + + inspected_model.relationships.items() + ) + + # Filter out excluded fields + polymorphic_on = get_polymorphic_on(model) + auto_orm_field_names = [] + for attr_name, attr in all_model_attrs.items(): + if ( + (only_fields and attr_name not in only_fields) + or (attr_name in exclude_fields) + or attr_name == polymorphic_on + ): continue - converted_composite = convert_sqlalchemy_composite(composite, registry) - registry.register_orm_field(obj_type, name, composite) - fields[name] = converted_composite - - for hybrid_item in inspected_model.all_orm_descriptors: - - if type(hybrid_item) == hybrid_property: - name = hybrid_item.__name__ - - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - - converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) - registry.register_orm_field(obj_type, name, hybrid_item) - fields[name] = converted_hybrid_property - - # Get all the columns for the relationships on the model - for relationship in inspected_model.relationships: - is_not_in_only = only_fields and relationship.key not in only_fields - # is_already_created = relationship.key in options.fields - is_excluded = relationship.key in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields + auto_orm_field_names.append(attr_name) + + # Gather all the ORM fields defined on the type + custom_orm_fields_items = [ + (attn_name, attr) + for base in reversed(obj_type.__mro__) + for attn_name, attr in base.__dict__.items() + if isinstance(attr, ORMField) + ] + custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1]) + + # Set the model_attr if not set + for orm_field_name, orm_field in custom_orm_fields_items: + attr_name = orm_field.kwargs.get("model_attr", orm_field_name) + if attr_name not in all_model_attrs: + raise ValueError( + ("Cannot map ORMField to a model attribute.\n" "Field: '{}.{}'").format( + obj_type.__name__, + orm_field_name, + ) + ) + orm_field.kwargs["model_attr"] = attr_name + + # Merge automatic fields with custom ORM fields + orm_fields = OrderedDict(custom_orm_fields_items) + for orm_field_name in auto_orm_field_names: + if orm_field_name in orm_fields: continue - converted_relationship = convert_sqlalchemy_relationship( - relationship, registry, connection_field_factory + orm_fields[orm_field_name] = ORMField(model_attr=orm_field_name) + + # Build all the field dictionary + fields = OrderedDict() + filters = OrderedDict() + for orm_field_name, orm_field in orm_fields.items(): + filtering_enabled_for_field = orm_field.kwargs.pop( + "create_filter", create_filters + ) + filter_type = orm_field.kwargs.pop("filter_type", None) + attr_name = orm_field.kwargs.pop("model_attr") + attr = all_model_attrs[attr_name] + resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver( + obj_type, attr_name ) - name = relationship.key - registry.register_orm_field(obj_type, name, relationship) - fields[name] = converted_relationship - return fields + if isinstance(attr, ColumnProperty): + field = convert_sqlalchemy_column( + attr, registry, resolver, **orm_field.kwargs + ) + elif isinstance(attr, RelationshipProperty): + batching_ = orm_field.kwargs.pop("batching", batching) + field = convert_sqlalchemy_relationship( + attr, + obj_type, + connection_field_factory, + batching_, + orm_field_name, + **orm_field.kwargs, + ) + elif isinstance(attr, CompositeProperty): + if attr_name != orm_field_name or orm_field.kwargs: + # TODO Add a way to override composite property fields + raise ValueError( + "ORMField kwargs for composite fields must be empty. " + "Field: {}.{}".format(obj_type.__name__, orm_field_name) + ) + field = convert_sqlalchemy_composite(attr, registry, resolver) + elif isinstance(attr, hybrid_property): + field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) + elif isinstance(attr, AssociationProxy): + field = convert_sqlalchemy_association_proxy( + model, + attr, + obj_type, + registry, + connection_field_factory, + batching, + resolver, + **orm_field.kwargs, + ) + else: + raise Exception("Property type is not supported") # Should never happen + + registry.register_orm_field(obj_type, orm_field_name, attr) + fields[orm_field_name] = field + if filtering_enabled_for_field and not isinstance(attr, AssociationProxy): + # we don't support filtering on association proxies yet. + # Support will be patched in a future release of graphene-sqlalchemy + filters[orm_field_name] = filter_field_from_type_field( + field, registry, filter_type, attr, attr_name + ) + return fields, filters -class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): - model = None # type: sqlalchemy.Model - registry = None # type: sqlalchemy.Registry - connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] - id = None # type: str +class SQLAlchemyBase(BaseType): + """ + This class contains initialization code that is common to both ObjectTypes + and Interfaces. You typically don't need to use it directly. + """ -class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__( cls, @@ -110,13 +437,23 @@ def __init_subclass_with_meta__( use_connection=None, interfaces=(), id=None, - connection_field_factory=default_connection_field_factory, + batching=False, + connection_field_factory=None, _meta=None, - **options + create_filters=True, + **options, ): - assert is_mapped_class(model), ( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' - ).format(cls.__name__, model) + # We always want to bypass this hook unless we're defining a concrete + # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. + if not _meta: + return + + # Make sure model is a valid SQLAlchemy model + if not is_mapped_class(model): + raise ValueError( + "You need to pass a valid SQLAlchemy Model in " + '{}.Meta, received "{}".'.format(cls.__name__, model) + ) if not registry: registry = get_global_registry() @@ -126,21 +463,31 @@ def __init_subclass_with_meta__( 'Registry, received "{}".' ).format(cls.__name__, registry) + if only_fields and exclude_fields: + raise ValueError( + "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." + ) + + fields, filters = construct_fields_and_filters( + obj_type=cls, + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + batching=batching, + create_filters=create_filters, + connection_field_factory=connection_field_factory, + ) + sqla_fields = yank_fields_from_attrs( - construct_fields( - obj_type=cls, - model=model, - registry=registry, - only_fields=only_fields, - exclude_fields=exclude_fields, - connection_field_factory=connection_field_factory, - ), + fields, _as=Field, + sort=False, ) if use_connection is None and interfaces: use_connection = any( - (issubclass(interface, Node) for interface in interfaces) + issubclass(interface, Node) for interface in interfaces ) if use_connection and not connection: @@ -157,9 +504,6 @@ def __init_subclass_with_meta__( "The connection must be a Connection. Received {}" ).format(connection.__name__) - if not _meta: - _meta = SQLAlchemyObjectTypeOptions(cls) - _meta.model = model _meta.registry = registry @@ -168,10 +512,25 @@ def __init_subclass_with_meta__( else: _meta.fields = sqla_fields + # Save Generated filter class in Meta Class + if create_filters and not _meta.filter_class: + # Map graphene fields to filters + # TODO we might need to pass the ORMFields containing the SQLAlchemy models + # to the scalar filters here (to generate expressions from the model) + + filter_fields = yank_fields_from_attrs(filters, _as=InputField, sort=False) + + _meta.filter_class = BaseTypeFilter.create_type( + f"{cls.__name__}Filter", filter_fields=filter_fields, model=model + ) + registry.register_filter_for_base_type(cls, _meta.filter_class) + _meta.connection = connection _meta.id = id or "id" - super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( + cls.connection = connection # Public way to get the connection + + super(SQLAlchemyBase, cls).__init_subclass_with_meta__( _meta=_meta, interfaces=interfaces, **options ) @@ -182,6 +541,11 @@ def __init_subclass_with_meta__( def is_type_of(cls, root, info): if isinstance(root, cls): return True + if isawaitable(root): + raise Exception( + "Received coroutine instead of sql alchemy model. " + "You seem to use an async engine with synchronous schema execution" + ) if not is_mapped_instance(root): raise Exception(('Received incompatible instance "{}".').format(root)) return isinstance(root, cls._meta.model) @@ -193,6 +557,19 @@ def get_query(cls, info): @classmethod def get_node(cls, info, id): + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + try: + return cls.get_query(info).get(id) + except NoResultFound: + return None + + session = get_session(info.context) + if isinstance(session, AsyncSession): + + async def get_result() -> Any: + return await session.get(cls._meta.model, id) + + return get_result() try: return cls.get_query(info).get(id) except NoResultFound: @@ -201,12 +578,126 @@ def get_node(cls, info, id): def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) - return tuple(keys) if len(keys) > 1 else keys[0] + return str(tuple(keys)) if len(keys) > 1 else keys[0] @classmethod def enum_for_field(cls, field_name): return enum_for_field(cls, field_name) + @classmethod + def get_filter_argument(cls): + if cls._meta.filter_class: + return graphene.Argument(cls._meta.filter_class) + return None + sort_enum = classmethod(sort_enum_for_object_type) sort_argument = classmethod(sort_argument_for_object_type) + + +class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): + model = None # type: sqlalchemy.Model + registry = None # type: sqlalchemy.Registry + connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] + id = None # type: str + filter_class: Type[BaseTypeFilter] = None + + +class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): + """ + This type represents the GraphQL ObjectType. It reflects on the + given SQLAlchemy model, and automatically generates an ObjectType + using the column and relationship information defined there. + + Usage: + + .. code-block:: python + + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String()) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel + """ + + @classmethod + def __init_subclass_with_meta__(cls, _meta=None, **options): + if not _meta: + _meta = SQLAlchemyObjectTypeOptions(cls) + + super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + +class SQLAlchemyInterfaceOptions(InterfaceOptions): + model = None # type: sqlalchemy.Model + registry = None # type: sqlalchemy.Registry + connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] + id = None # type: str + filter_class: Type[BaseTypeFilter] = None + + +class SQLAlchemyInterface(SQLAlchemyBase, Interface): + """ + This type represents the GraphQL Interface. It reflects on the + given SQLAlchemy model, and automatically generates an Interface + using the column and relationship information defined there. This + is used to construct interface relationships based on polymorphic + inheritance hierarchies in SQLAlchemy. + + Please note that by default, the "polymorphic_on" column is *not* + generated as a field on types that use polymorphic inheritance, as + this is considered an implentation detail. The idiomatic way to + retrieve the concrete GraphQL type of an object is to query for the + `__typename` field. + + Usage (using joined table inheritance): + + .. code-block:: python + + class MyBaseModel(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + + __mapper_args__ = { + "polymorphic_on": type, + } + + class MyChildModel(Base): + date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "child", + } + + class MyBaseType(SQLAlchemyInterface): + class Meta: + model = MyBaseModel + + class MyChildType(SQLAlchemyObjectType): + class Meta: + model = MyChildModel + interfaces = (MyBaseType,) + """ + + @classmethod + def __init_subclass_with_meta__(cls, _meta=None, **options): + if not _meta: + _meta = SQLAlchemyInterfaceOptions(cls) + + super(SQLAlchemyInterface, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + # make sure that the model doesn't have a polymorphic_identity defined + if hasattr(_meta.model, "__mapper__"): + polymorphic_identity = _meta.model.__mapper__.polymorphic_identity + assert ( + polymorphic_identity is None + ), '{}: An interface cannot map to a concrete type (polymorphic_identity is "{}")'.format( + cls.__name__, polymorphic_identity + ) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 7139eefc..17d774d2 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,10 +1,49 @@ import re +import typing import warnings +from collections import OrderedDict +from functools import _c3_mro +from importlib.metadata import version as get_version +from typing import Any, Callable, Dict, Optional +from packaging import version +from sqlalchemy import select from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene import NonNull + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type + + +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return version.parse(get_version("SQLAlchemy")) < version.parse(version_string) + + +def is_graphene_version_less_than(version_string): # pragma: no cover + """Check the installed graphene version""" + return version.parse(get_version("graphene")) < version.parse(version_string) + + +SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False + +if not is_sqlalchemy_version_less_than("1.4"): # pragma: no cover + from sqlalchemy.ext.asyncio import AsyncSession + + SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True + + +SQL_VERSION_HIGHER_EQUAL_THAN_2 = False + +if not is_sqlalchemy_version_less_than("2.0.0b1"): # pragma: no cover + SQL_VERSION_HIGHER_EQUAL_THAN_2 = True + def get_session(context): return context.get("session") @@ -19,6 +58,8 @@ def get_query(model, context): "A query in the model Base or a session in the schema is required for querying.\n" "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return select(model) query = session.query(model) return query @@ -26,7 +67,13 @@ def get_query(model, context): def is_mapped_class(cls): try: class_mapper(cls) - except (ArgumentError, UnmappedClassError): + except ArgumentError as error: + # Only handle ArgumentErrors for non-class objects + if "Class object expected" in str(error): + return False + raise + except UnmappedClassError: + # Unmapped classes return false return False else: return True @@ -80,7 +127,6 @@ def _deprecated_default_symbol_name(column_name, sort_asc): def _deprecated_object_type_for_model(cls, name): - try: return _deprecated_object_type_cache[cls, name] except KeyError: @@ -130,6 +176,7 @@ def sort_argument_for_model(cls, has_default=True): ) from graphene import Argument, List + from .enums import sort_enum_for_object_type enum = sort_enum_for_object_type( @@ -140,3 +187,90 @@ def sort_argument_for_model(cls, has_default=True): enum.default = None return Argument(List(enum), default_value=enum.default) + + +class singledispatchbymatchfunction: + """ + Inspired by @singledispatch, this is a variant that works using a matcher function + instead of relying on the type of the first argument. + The register method can be used to register a new matcher, which is passed as the first argument: + """ + + def __init__(self, default: Callable): + self.registry: Dict[Callable, Callable] = OrderedDict() + self.default = default + + def __call__(self, *args, **kwargs): + matched_arg = args[0] + try: + mro = _c3_mro(matched_arg) + except Exception: + # In case of tuples or similar types, we can't use the MRO. + # Fall back to just matching the original argument. + mro = [matched_arg] + + for cls in mro: + for matcher_function, final_method in self.registry.items(): + # Register order is important. First one that matches, runs. + if matcher_function(cls): + return final_method(*args, **kwargs) + + # No match, using default. + return self.default(*args, **kwargs) + + def register(self, matcher_function: Callable[[Any], bool], func=None): + if func is None: + return lambda f: self.register(matcher_function, f) + self.registry[matcher_function] = func + return func + + +def column_type_eq(value: Any) -> Callable[[Any], bool]: + """A simple function that makes the equality based matcher functions for + SingleDispatchByMatchFunction prettier""" + return lambda x: (x == value) + + +def safe_isinstance(cls): + def safe_isinstance_checker(arg): + try: + return isinstance(arg, cls) + except TypeError: + pass + + return safe_isinstance_checker + + +def safe_issubclass(cls): + def safe_issubclass_checker(arg): + try: + return issubclass(arg, cls) + except TypeError: + pass + + return safe_issubclass_checker + + +def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: + from graphene_sqlalchemy.registry import get_global_registry + + try: + return next( + filter( + lambda x: x.__name__ == model_name, + list(get_global_registry()._registry.keys()), + ) + ) + except StopIteration: + pass + + +def is_list(x): + return getattr(x, "__origin__", None) in [list, typing.List] + + +class DummyImport: + """The dummy module returns 'object' for a query for any member""" + + def __getattr__(self, name): + return object diff --git a/setup.cfg b/setup.cfg index 046db9dc..e479585c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,14 +2,16 @@ test=pytest [flake8] -exclude = setup.py,docs/*,examples/*,tests +ignore = E203,W503 +exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs,setup.py,docs/*,examples/*,tests max-line-length = 120 [isort] +profile = black no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=app,database,flask,models,nameko,promise,py,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=aiodataloader,app,database,flask,models,nameko,pkg_resources,promise,pytest,schema,setuptools,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER skip_glob=examples/nameko_sqlalchemy diff --git a/setup.py b/setup.py index 66704b28..33eabcb6 100644 --- a/setup.py +++ b/setup.py @@ -13,30 +13,34 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene - "graphene>=2.1.3,<3", - # Tests fail with 1.0.19 - "SQLAlchemy>=1.1,<2", - "six>=1.10.0,<2", - "singledispatch>=3.4.0.3,<4", + "graphene>=3.0.0b7", + "promise>=2.3", + "SQLAlchemy>=1.1", + "aiodataloader>=0.2.0,<1.0", + "packaging>=23.0", ] -try: - import enum -except ImportError: # Python < 2.7 and Python 3.3 - requirements.append("enum34 >= 1.1.6") tests_require = [ - "pytest==4.3.1", - "mock==2.0.0", - "pytest-cov==2.6.1", - "sqlalchemy_utils==0.33.9", + "pytest>=6.2.0,<7.0", + "pytest-asyncio>=0.18.3", + "pytest-cov>=2.11.0,<3.0", + "sqlalchemy_utils>=0.37.0,<1.0", + "pytest-benchmark>=3.4.0,<4.0", + "aiosqlite>=0.17.0", + "nest-asyncio", + "greenlet", ] setup( name="graphene-sqlalchemy", version=version, description="Graphene SQLAlchemy integration", - long_description=open("README.rst").read(), + long_description=open("README.md").read(), + long_description_content_type="text/markdown", url="https://github.com/graphql-python/graphene-sqlalchemy", + project_urls={ + "Documentation": "https://docs.graphene-python.org/projects/sqlalchemy/en/latest", + }, author="Syrus Akbary", author_email="me@syrusakbary.com", license="MIT", @@ -44,24 +48,22 @@ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.3", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: PyPy", ], - keywords="api graphql protocol rest relay graphene", + keywords="api graphql protocol rest relay graphene sqlalchemy", packages=find_packages(exclude=["tests"]), install_requires=requirements, extras_require={ "dev": [ "tox==3.7.0", # Should be kept in sync with tox.ini - "coveralls==1.7.0", - "pre-commit==1.14.4", + "pre-commit==2.19", + "flake8==4.0.0", ], "test": tests_require, }, diff --git a/tox.ini b/tox.ini index e55f7d9b..6ec4699e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,20 +1,45 @@ [tox] -envlist = pre-commit,py{27,34,35,36,37}-sql{11,12,13} +envlist = pre-commit,py{39,310,311,312,313}-sql{12,13,14,20} skipsdist = true minversion = 3.7.0 +[gh-actions] +python = + 3.9: py39 + 3.10: py310 + 3.11: py311 + 3.12: py312 + 3.13: py313 + +[gh-actions:env] +SQLALCHEMY = + 1.2: sql12 + 1.3: sql13 + 1.4: sql14 + 2.0: sql20 + [testenv] +passenv = GITHUB_* deps = .[test] - sql11: sqlalchemy>=1.1,<1.2 sql12: sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 + sql14: sqlalchemy>=1.4,<1.5 + sql20: sqlalchemy>=2.0.0b3 +setenv = + SQLALCHEMY_WARN_20 = 1 commands = - pytest graphene_sqlalchemy --cov=graphene_sqlalchemy {posargs} + python -W always -m pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} [testenv:pre-commit] -basepython=python3.7 +basepython=python3.10 deps = .[dev] commands = pre-commit {posargs:run --all-files} + +[testenv:flake8] +basepython = python3.10 +deps = -e.[dev] +commands = + flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120