From d83ab3161b801bea3e1dd64ff93a5f0c4131d499 Mon Sep 17 00:00:00 2001 From: adehad <26027314+adehad@users.noreply.github.com> Date: Sun, 4 Jun 2023 12:42:34 +0100 Subject: [PATCH 1/6] add a basic pre-commit config --- .pre-commit-config.yaml | 52 ++++++++++++++++++++++++++++++++++++++++ Makefile | 6 ++++- pyproject.toml | 53 +++++++++++++++++++++++++++++++++++++++++ tests/ruff.toml | 1 + 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 .pre-commit-config.yaml create mode 100644 tests/ruff.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..9ae6daca --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,52 @@ +--- +repos: + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: "v0.0.275" + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: mixed-line-ending + - id: check-byte-order-marker + - id: check-executables-have-shebangs + - id: check-merge-conflict + - id: check-symlinks + - id: check-vcs-permalinks + - id: debug-statements + - id: check-yaml + files: .*\.(yaml|yml)$ + - repo: https://github.com/codespell-project/codespell + rev: v2.2.4 + hooks: + - id: codespell + name: codespell + description: Checks for common misspellings in text files. + entry: codespell + language: python + types: [text] + args: [] + require_serial: false + additional_dependencies: [] + - repo: https://github.com/adrienverge/yamllint + rev: v1.31.0 + hooks: + - id: yamllint + files: \.(yaml|yml)$ +# - repo: https://github.com/pre-commit/mirrors-mypy +# rev: v1.3.0 +# hooks: +# - id: mypy +# additional_dependencies: +# - types-requests +# - types-pkg_resources +# args: +# [--no-strict-optional, --ignore-missing-imports, --show-error-codes] diff --git a/Makefile b/Makefile index 4244be36..d75e6e62 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: compile debug test quicktest clean all gen-errors gen-types _touch +.PHONY: compile debug lint test quicktest clean all gen-errors gen-types _touch PYTHON ?= python @@ -51,6 +51,10 @@ debug: _touch EDGEDB_DEBUG=1 $(PYTHON) setup.py build_ext --inplace +lint: + $(PYTHON) -m pip install pre-commit + pre-commit run --all-files + test: PYTHONASYNCIODEBUG=1 $(PYTHON) setup.py test $(PYTHON) setup.py test diff --git a/pyproject.toml b/pyproject.toml index fed528d4..75c7e385 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,56 @@ [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" + + +######################################################################################## +# External Tool Config +######################################################################################## +[tool.mypy] +python_version = 3.8 +warn_unused_configs = true +namespace_packages = true +# plugins = "pydantic.mypy" + +[tool.ruff] +select = [ + "E", # pydocstyle + "W", # pydocstyle + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "D", # docstrings + "RUF", # ruff +] +ignore = [ + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D107", # Missing docstring in `__init__` + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + # + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` +] + +# Same as Black. +line-length = 88 + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +# Assume Python 3.8. (minimum supported) +target-version = "py38" + +# The source code paths to consider, e.g., when resolving first- vs. third-party imports +src = ["edgedb", "tests"] + +[tool.ruff.isort] +known-first-party = ["edgedb", "tests"] +required-imports = ["from __future__ import annotations"] + +[tool.ruff.pydocstyle] +# Use Google-style docstrings. +convention = "google" diff --git a/tests/ruff.toml b/tests/ruff.toml new file mode 100644 index 00000000..7298a65b --- /dev/null +++ b/tests/ruff.toml @@ -0,0 +1 @@ +ignore = ["D101", "D102"] From a51437729254f8bc195f3808a353cdc74e54d336 Mon Sep 17 00:00:00 2001 From: adehad <26027314+adehad@users.noreply.github.com> Date: Sat, 1 Jul 2023 22:50:51 +0100 Subject: [PATCH 2/6] use 79 for line length --- pyproject.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 75c7e385..7c9c0bb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,11 @@ warn_unused_configs = true namespace_packages = true # plugins = "pydantic.mypy" + +[tool.black] +line-length = 79 + + [tool.ruff] select = [ "E", # pydocstyle @@ -36,7 +41,7 @@ ignore = [ ] # Same as Black. -line-length = 88 +line-length = 79 # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" From 8dabefdec69323ea0c3d390e525f1fa45796f41d Mon Sep 17 00:00:00 2001 From: adehad <26027314+adehad@users.noreply.github.com> Date: Sat, 1 Jul 2023 23:06:07 +0100 Subject: [PATCH 3/6] run autformatter and linters --- docs/README.md | 2 +- docs/api/codegen.rst | 1 - docs/conf.py | 97 ++-- docs/index.rst | 2 +- edgedb/__init__.py | 211 ++++---- edgedb/_taskgroup.py | 67 +-- edgedb/_testbase.py | 208 +++---- edgedb/_version.py | 3 +- edgedb/abstract.py | 263 ++++----- edgedb/asyncio_client.py | 95 ++-- edgedb/base_client.py | 122 ++--- edgedb/blocking_client.py | 75 +-- edgedb/codegen/__main__.py | 2 +- edgedb/codegen/cli.py | 6 +- edgedb/codegen/generator.py | 25 +- edgedb/color.py | 24 +- edgedb/con_utils.py | 558 ++++++++++--------- edgedb/connresource.py | 18 +- edgedb/credentials.py | 61 ++- edgedb/datatypes/range.py | 38 +- edgedb/describe.py | 13 +- edgedb/enums.py | 43 +- edgedb/errors/__init__.py | 195 ++++--- edgedb/errors/_base.py | 59 +- edgedb/errors/tags.py | 18 +- edgedb/introspect.py | 22 +- edgedb/options.py | 66 ++- edgedb/platform.py | 4 + edgedb/scram/__init__.py | 238 ++++---- edgedb/scram/saslprep.py | 31 +- edgedb/transaction.py | 80 +-- setup.py | 236 ++++---- tests/__init__.py | 7 +- tests/bench_uuid.py | 45 +- tests/codegen/test-project1/linked | 2 +- tests/datatypes/test_datatypes.py | 675 +++++++++-------------- tests/datatypes/test_uuid.py | 85 +-- tests/test_async_query.py | 844 ++++++++++++++++------------- tests/test_async_retry.py | 84 +-- tests/test_async_tx.py | 26 +- tests/test_asyncio_client.py | 33 +- tests/test_blocking_client.py | 30 +- tests/test_codegen.py | 8 +- tests/test_con_utils.py | 527 +++++++++--------- tests/test_connect.py | 39 +- tests/test_credentials.py | 94 ++-- tests/test_datetime.py | 41 +- tests/test_enum.py | 37 +- tests/test_errors.py | 15 +- tests/test_globals.py | 47 +- tests/test_memory.py | 24 +- tests/test_proto.py | 45 +- tests/test_scram.py | 51 +- tests/test_sourcecode.py | 15 +- tests/test_sync_query.py | 696 ++++++++++++------------ tests/test_sync_retry.py | 87 +-- tests/test_sync_tx.py | 26 +- tools/gen_init.py | 39 +- 58 files changed, 3339 insertions(+), 3166 deletions(-) diff --git a/docs/README.md b/docs/README.md index 19f81a9c..6b106020 100644 --- a/docs/README.md +++ b/docs/README.md @@ -39,4 +39,4 @@ Test and check the links found in the documentation: ``` $ make linkcheck -``` \ No newline at end of file +``` diff --git a/docs/api/codegen.rst b/docs/api/codegen.rst index c62abccf..7616ff5d 100644 --- a/docs/api/codegen.rst +++ b/docs/api/codegen.rst @@ -93,4 +93,3 @@ The ``edgedb-py`` command supports the same set of :ref:`connection options --password-from-stdin --tls-ca-file --tls-security - diff --git a/docs/conf.py b/docs/conf.py index 013c9016..3426c51f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,33 +1,37 @@ #!/usr/bin/env python3 +from __future__ import annotations -import alabaster import os import sys -sys.path.insert(0, os.path.abspath('..')) +import alabaster + +sys.path.insert(0, os.path.abspath("..")) -version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), - 'edgedb', '_version.py') +version_file = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "edgedb", "_version.py" +) -with open(version_file, 'r') as f: +with open(version_file) as f: for line in f: - if line.startswith('__version__ ='): - _, _, version = line.partition('=') + if line.startswith("__version__ ="): + _, _, version = line.partition("=") version = version.strip(" \n'\"") break else: raise RuntimeError( - 'unable to read the version from edgedb/_version.py') + "unable to read the version from edgedb/_version.py" + ) # -- General configuration ------------------------------------------------ extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinx.ext.intersphinx', - 'sphinxcontrib.asyncio', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", + "sphinx.ext.intersphinx", + "sphinxcontrib.asyncio", ] # This is done on purpose: multiple different documentations with @@ -38,41 +42,41 @@ add_module_names = False -templates_path = ['_templates'] -source_suffix = '.rst' -master_doc = 'index' -project = 'edgedb' -copyright = '2018-present MagicStack Inc. and the EdgeDB authors.' -author = 'MagicStack Inc. and the EdgeDB authors' +templates_path = ["_templates"] +source_suffix = ".rst" +master_doc = "index" +project = "edgedb" +copyright = "2018-present MagicStack Inc. and the EdgeDB authors." +author = "MagicStack Inc. and the EdgeDB authors" release = version language = None -exclude_patterns = ['_build'] -pygments_style = 'sphinx' +exclude_patterns = ["_build"] +pygments_style = "sphinx" todo_include_todos = False -suppress_warnings = ['image.nonlocal_uri'] +suppress_warnings = ["image.nonlocal_uri"] # -- Options for HTML output ---------------------------------------------- -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_path = [alabaster.get_path()] -html_title = 'EdgeDB Python Driver Documentation' -html_short_title = 'edgedb' -html_static_path = ['_static'] +html_title = "EdgeDB Python Driver Documentation" +html_short_title = "edgedb" +html_static_path = ["_static"] html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', + "**": [ + "about.html", + "navigation.html", ] } html_show_sourcelink = False html_show_sphinx = False html_show_copyright = True html_context = { - 'css_files': [ - '_static/theme_overrides.css', + "css_files": [ + "_static/theme_overrides.css", ], } -htmlhelp_basename = 'edgedbdoc' +htmlhelp_basename = "edgedbdoc" # -- Options for LaTeX output --------------------------------------------- @@ -80,28 +84,37 @@ latex_elements = {} latex_documents = [ - (master_doc, 'edgedb.tex', 'EdgeDB Python Driver Documentation', - author, 'manual'), + ( + master_doc, + "edgedb.tex", + "EdgeDB Python Driver Documentation", + author, + "manual", + ), ] # -- Options for manual page output --------------------------------------- man_pages = [ - (master_doc, 'edgedb', 'EdgeDB Python Driver Documentation', - [author], 1) + (master_doc, "edgedb", "EdgeDB Python Driver Documentation", [author], 1) ] # -- Options for Texinfo output ------------------------------------------- texinfo_documents = [ - (master_doc, 'edgedb', 'EdgeDB Python Driver Documentation', - author, 'edgedb', - 'Official EdgeDB Python Driver', - 'Miscellaneous'), + ( + master_doc, + "edgedb", + "EdgeDB Python Driver Documentation", + author, + "edgedb", + "Official EdgeDB Python Driver", + "Miscellaneous", + ), ] # -- Options for intersphinx ---------------------------------------------- -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} diff --git a/docs/index.rst b/docs/index.rst index 1a647fb6..7fb093b2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,7 +35,7 @@ and :ref:`asyncio ` implementations. * :ref:`edgedb-python-codegen` Python code generation command-line tool documentation. - + * :ref:`edgedb-python-advanced` Advanced usages of the state and optional customization. diff --git a/edgedb/__init__.py b/edgedb/__init__.py index bf565921..244af9d7 100644 --- a/edgedb/__init__.py +++ b/edgedb/__init__.py @@ -22,19 +22,24 @@ from ._version import __version__ from edgedb.datatypes.datatypes import ( - Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory + Tuple, + NamedTuple, + EnumValue, + RelativeDuration, + DateDuration, + ConfigMemory, ) from edgedb.datatypes.datatypes import Set, Object, Array, Link, LinkSet from edgedb.datatypes.range import Range from .abstract import ( - Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor, + Executor, + AsyncIOExecutor, + ReadOnlyExecutor, + AsyncIOReadOnlyExecutor, ) -from .asyncio_client import ( - create_async_client, - AsyncIOClient -) +from .asyncio_client import create_async_client, AsyncIOClient from .blocking_client import create_client, Client from .enums import Cardinality, ElementKind @@ -179,100 +184,102 @@ InternalClientError, ) -__all__.extend([ - "InternalServerError", - "UnsupportedFeatureError", - "ProtocolError", - "BinaryProtocolError", - "UnsupportedProtocolVersionError", - "TypeSpecNotFoundError", - "UnexpectedMessageError", - "InputDataError", - "ParameterTypeMismatchError", - "StateMismatchError", - "ResultCardinalityMismatchError", - "CapabilityError", - "UnsupportedCapabilityError", - "DisabledCapabilityError", - "QueryError", - "InvalidSyntaxError", - "EdgeQLSyntaxError", - "SchemaSyntaxError", - "GraphQLSyntaxError", - "InvalidTypeError", - "InvalidTargetError", - "InvalidLinkTargetError", - "InvalidPropertyTargetError", - "InvalidReferenceError", - "UnknownModuleError", - "UnknownLinkError", - "UnknownPropertyError", - "UnknownUserError", - "UnknownDatabaseError", - "UnknownParameterError", - "SchemaError", - "SchemaDefinitionError", - "InvalidDefinitionError", - "InvalidModuleDefinitionError", - "InvalidLinkDefinitionError", - "InvalidPropertyDefinitionError", - "InvalidUserDefinitionError", - "InvalidDatabaseDefinitionError", - "InvalidOperatorDefinitionError", - "InvalidAliasDefinitionError", - "InvalidFunctionDefinitionError", - "InvalidConstraintDefinitionError", - "InvalidCastDefinitionError", - "DuplicateDefinitionError", - "DuplicateModuleDefinitionError", - "DuplicateLinkDefinitionError", - "DuplicatePropertyDefinitionError", - "DuplicateUserDefinitionError", - "DuplicateDatabaseDefinitionError", - "DuplicateOperatorDefinitionError", - "DuplicateViewDefinitionError", - "DuplicateFunctionDefinitionError", - "DuplicateConstraintDefinitionError", - "DuplicateCastDefinitionError", - "SessionTimeoutError", - "IdleSessionTimeoutError", - "QueryTimeoutError", - "TransactionTimeoutError", - "IdleTransactionTimeoutError", - "ExecutionError", - "InvalidValueError", - "DivisionByZeroError", - "NumericOutOfRangeError", - "AccessPolicyError", - "IntegrityError", - "ConstraintViolationError", - "CardinalityViolationError", - "MissingRequiredError", - "TransactionError", - "TransactionConflictError", - "TransactionSerializationError", - "TransactionDeadlockError", - "ConfigurationError", - "AccessError", - "AuthenticationError", - "AvailabilityError", - "BackendUnavailableError", - "BackendError", - "UnsupportedBackendFeatureError", - "LogMessage", - "WarningMessage", - "ClientError", - "ClientConnectionError", - "ClientConnectionFailedError", - "ClientConnectionFailedTemporarilyError", - "ClientConnectionTimeoutError", - "ClientConnectionClosedError", - "InterfaceError", - "QueryArgumentError", - "MissingArgumentError", - "UnknownArgumentError", - "InvalidArgumentError", - "NoDataError", - "InternalClientError", -]) +__all__.extend( + [ + "InternalServerError", + "UnsupportedFeatureError", + "ProtocolError", + "BinaryProtocolError", + "UnsupportedProtocolVersionError", + "TypeSpecNotFoundError", + "UnexpectedMessageError", + "InputDataError", + "ParameterTypeMismatchError", + "StateMismatchError", + "ResultCardinalityMismatchError", + "CapabilityError", + "UnsupportedCapabilityError", + "DisabledCapabilityError", + "QueryError", + "InvalidSyntaxError", + "EdgeQLSyntaxError", + "SchemaSyntaxError", + "GraphQLSyntaxError", + "InvalidTypeError", + "InvalidTargetError", + "InvalidLinkTargetError", + "InvalidPropertyTargetError", + "InvalidReferenceError", + "UnknownModuleError", + "UnknownLinkError", + "UnknownPropertyError", + "UnknownUserError", + "UnknownDatabaseError", + "UnknownParameterError", + "SchemaError", + "SchemaDefinitionError", + "InvalidDefinitionError", + "InvalidModuleDefinitionError", + "InvalidLinkDefinitionError", + "InvalidPropertyDefinitionError", + "InvalidUserDefinitionError", + "InvalidDatabaseDefinitionError", + "InvalidOperatorDefinitionError", + "InvalidAliasDefinitionError", + "InvalidFunctionDefinitionError", + "InvalidConstraintDefinitionError", + "InvalidCastDefinitionError", + "DuplicateDefinitionError", + "DuplicateModuleDefinitionError", + "DuplicateLinkDefinitionError", + "DuplicatePropertyDefinitionError", + "DuplicateUserDefinitionError", + "DuplicateDatabaseDefinitionError", + "DuplicateOperatorDefinitionError", + "DuplicateViewDefinitionError", + "DuplicateFunctionDefinitionError", + "DuplicateConstraintDefinitionError", + "DuplicateCastDefinitionError", + "SessionTimeoutError", + "IdleSessionTimeoutError", + "QueryTimeoutError", + "TransactionTimeoutError", + "IdleTransactionTimeoutError", + "ExecutionError", + "InvalidValueError", + "DivisionByZeroError", + "NumericOutOfRangeError", + "AccessPolicyError", + "IntegrityError", + "ConstraintViolationError", + "CardinalityViolationError", + "MissingRequiredError", + "TransactionError", + "TransactionConflictError", + "TransactionSerializationError", + "TransactionDeadlockError", + "ConfigurationError", + "AccessError", + "AuthenticationError", + "AvailabilityError", + "BackendUnavailableError", + "BackendError", + "UnsupportedBackendFeatureError", + "LogMessage", + "WarningMessage", + "ClientError", + "ClientConnectionError", + "ClientConnectionFailedError", + "ClientConnectionFailedTemporarilyError", + "ClientConnectionTimeoutError", + "ClientConnectionClosedError", + "InterfaceError", + "QueryArgumentError", + "MissingArgumentError", + "UnknownArgumentError", + "InvalidArgumentError", + "NoDataError", + "InternalClientError", + ] +) # diff --git a/edgedb/_taskgroup.py b/edgedb/_taskgroup.py index 0a1859d7..f95c9dd8 100644 --- a/edgedb/_taskgroup.py +++ b/edgedb/_taskgroup.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import asyncio import functools @@ -25,10 +25,9 @@ class TaskGroup: - def __init__(self, *, name=None): if name is None: - self._name = f'tg-{_name_counter()}' + self._name = f"tg-{_name_counter()}" else: self._name = str(name) @@ -48,37 +47,37 @@ def get_name(self): return self._name def __repr__(self): - msg = f''5 minutes'; - """) + """ + ) _default_cluster = { - 'proc': p, - 'client': client, - 'con_args': con_args, + "proc": p, + "client": client, + "con_args": con_args, } - if 'tls_cert_file' in data: + if "tls_cert_file" in data: # Keep the temp dir which we also copied the cert from WSL - _default_cluster['_tmpdir'] = tmpdir + _default_cluster["_tmpdir"] = tmpdir atexit.register(client.close) except Exception as e: @@ -201,7 +201,7 @@ class TestCaseMeta(type(unittest.TestCase)): def _iter_methods(bases, ns): for base in bases: for methname in dir(base): - if not methname.startswith('test_'): + if not methname.startswith("test_"): continue meth = getattr(base, methname) @@ -211,7 +211,7 @@ def _iter_methods(bases, ns): yield methname, meth for methname, meth in ns.items(): - if not methname.startswith('test_'): + if not methname.startswith("test_"): continue if not inspect.iscoroutinefunction(meth): @@ -232,14 +232,15 @@ def wrapper(self, *args, __meth__=meth, **kwargs): # than hunting them down every time, simply # retry the test. self.loop.run_until_complete( - __meth__(self, *args, **kwargs)) + __meth__(self, *args, **kwargs) + ) except edgedb.TransactionSerializationError: if try_no == 3: raise else: - self.loop.run_until_complete(self.client.execute( - 'ROLLBACK;' - )) + self.loop.run_until_complete( + self.client.execute("ROLLBACK;") + ) try_no += 1 else: break @@ -257,12 +258,13 @@ def __new__(mcls, name, bases, ns): mcls.add_method(methname, ns, meth) cls = super().__new__(mcls, name, bases, ns) - if not ns.get('BASE_TEST_CLASS') and hasattr(cls, 'get_database_name'): + if not ns.get("BASE_TEST_CLASS") and hasattr(cls, "get_database_name"): dbname = cls.get_database_name() if name in mcls._database_names: raise TypeError( - f'{name} wants duplicate database name: {dbname}') + f"{name} wants duplicate database name: {dbname}" + ) mcls._database_names.add(name) @@ -282,7 +284,7 @@ def tearDownClass(cls): asyncio.set_event_loop(None) def add_fail_notes(self, **kwargs): - if not hasattr(self, 'fail_notes'): + if not hasattr(self, "fail_notes"): self.fail_notes = {} self.fail_notes.update(kwargs) @@ -296,8 +298,7 @@ def annotate(self, **kwargs): raise @contextlib.contextmanager - def assertRaisesRegex(self, exception, regex, msg=None, - **kwargs): + def assertRaisesRegex(self, exception, regex, msg=None, **kwargs): with super().assertRaisesRegex(exception, regex, msg=msg): try: yield @@ -307,9 +308,10 @@ def assertRaisesRegex(self, exception, regex, msg=None, val = getattr(e, attr_name) if val != expected_val: raise self.failureException( - f'{exception.__name__} context attribute ' - f'{attr_name!r} is {val} (expected ' - f'{expected_val!r})') from e + f"{exception.__name__} context attribute " + f"{attr_name!r} is {val} (expected " + f"{expected_val!r})" + ) from e raise def addCleanup(self, func, *args, **kwargs): @@ -318,11 +320,11 @@ def cleanup(): res = func(*args, **kwargs) if inspect.isawaitable(res): self.loop.run_until_complete(res) + super().addCleanup(cleanup) class ClusterTestCase(TestCase): - BASE_TEST_CLASS = True @classmethod @@ -367,15 +369,17 @@ class ConnectedTestCaseMixin: @classmethod def make_test_client( - cls, *, + cls, + *, cluster=None, - database='edgedb', - user='edgedb', - password='test', + database="edgedb", + user="edgedb", + password="test", connection_class=..., ): conargs = cls.get_connect_args( - cluster=cluster, database=database, user=user, password=password) + cluster=cluster, database=database, user=user, password=password + ) if connection_class is ...: connection_class = ( asyncio_client.AsyncIOConnection @@ -389,15 +393,11 @@ def make_test_client( ) @classmethod - def get_connect_args(cls, *, - cluster=None, - database='edgedb', - user='edgedb', - password='test'): - conargs = cls.cluster['con_args'].copy() - conargs.update(dict(user=user, - password=password, - database=database)) + def get_connect_args( + cls, *, cluster=None, database="edgedb", user="edgedb", password="test" + ): + conargs = cls.cluster["con_args"].copy() + conargs.update(dict(user=user, password=password, database=database)) return conargs @classmethod @@ -418,22 +418,21 @@ class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin): def setUp(self): if self.SETUP_METHOD: - self.adapt_call( - self.client.execute(self.SETUP_METHOD)) + self.adapt_call(self.client.execute(self.SETUP_METHOD)) super().setUp() def tearDown(self): try: if self.TEARDOWN_METHOD: - self.adapt_call( - self.client.execute(self.TEARDOWN_METHOD)) + self.adapt_call(self.client.execute(self.TEARDOWN_METHOD)) finally: try: if self.client.connection.is_in_transaction(): raise AssertionError( - 'test connection is still in transaction ' - '*after* the test') + "test connection is still in transaction " + "*after* the test" + ) finally: super().tearDown() @@ -445,11 +444,11 @@ def setUpClass(cls): cls.admin_client = None - class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP') + class_set_up = os.environ.get("EDGEDB_TEST_CASES_SET_UP") # Only open an extra admin connection if necessary. if not class_set_up: - script = f'CREATE DATABASE {dbname};' + script = f"CREATE DATABASE {dbname};" cls.admin_client = cls.make_test_client() cls.adapt_call(cls.admin_client.execute(script)) @@ -461,23 +460,27 @@ def setUpClass(cls): # The setup is expected to contain a CREATE MIGRATION, # which needs to be wrapped in a transaction. if cls.is_client_async: + async def execute(): async for tr in cls.client.transaction(): async with tr: await tr.execute(script) + else: + def execute(): for tr in cls.client.transaction(): with tr: tr.execute(script) + cls.adapt_call(execute()) @classmethod def get_database_name(cls): - if cls.__name__.startswith('TestEdgeQL'): - dbname = cls.__name__[len('TestEdgeQL'):] - elif cls.__name__.startswith('Test'): - dbname = cls.__name__[len('Test'):] + if cls.__name__.startswith("TestEdgeQL"): + dbname = cls.__name__[len("TestEdgeQL") :] + elif cls.__name__.startswith("Test"): + dbname = cls.__name__[len("Test") :] else: dbname = cls.__name__ @@ -485,27 +488,26 @@ def get_database_name(cls): @classmethod def get_setup_script(cls): - script = '' + script = "" # Look at all SCHEMA entries and potentially create multiple # modules, but always create the 'test' module. - schema = ['\nmodule test {}'] + schema = ["\nmodule test {}"] for name, val in cls.__dict__.items(): - m = re.match(r'^SCHEMA(?:_(\w+))?', name) + m = re.match(r"^SCHEMA(?:_(\w+))?", name) if m: - module_name = (m.group(1) or 'test').lower().replace( - '__', '.') + module_name = (m.group(1) or "test").lower().replace("__", ".") - with open(val, 'r') as sf: + with open(val) as sf: module = sf.read() - schema.append(f'\nmodule {module_name} {{ {module} }}') + schema.append(f"\nmodule {module_name} {{ {module} }}") # Don't wrap the script into a transaction here, so that # potentially it's easier to stitch multiple such scripts # together in a fashion similar to what `edb inittestdb` does. script += f'\nSTART MIGRATION TO {{ {"".join(schema)} }};' - script += f'\nPOPULATE MIGRATION; \nCOMMIT MIGRATION;' + script += "\nPOPULATE MIGRATION; \nCOMMIT MIGRATION;" if cls.SETUP: if not isinstance(cls.SETUP, (list, tuple)): @@ -514,27 +516,26 @@ def get_setup_script(cls): scripts = cls.SETUP for scr in scripts: - if '\n' not in scr and os.path.exists(scr): - with open(scr, 'rt') as f: + if "\n" not in scr and os.path.exists(scr): + with open(scr) as f: setup = f.read() else: setup = scr - script += '\n' + setup + script += "\n" + setup - return script.strip(' \n') + return script.strip(" \n") @classmethod def tearDownClass(cls): - script = '' + script = "" if cls.TEARDOWN: script = cls.TEARDOWN.strip() try: if script: - cls.adapt_call( - cls.client.execute(script)) + cls.adapt_call(cls.client.execute(script)) finally: try: if cls.is_client_async: @@ -543,13 +544,12 @@ def tearDownClass(cls): cls.client.close() dbname = cls.get_database_name() - script = f'DROP DATABASE {dbname};' + script = f"DROP DATABASE {dbname};" retry = cls.TEARDOWN_RETRY_DROP_DB for i in range(retry): try: - cls.adapt_call( - cls.admin_client.execute(script)) + cls.adapt_call(cls.admin_client.execute(script)) except edgedb.errors.ExecutionError: if i < retry - 1: time.sleep(0.1) @@ -559,15 +559,14 @@ def tearDownClass(cls): break except Exception: - log.exception('error running teardown') + log.exception("error running teardown") # skip the exception so that original error is shown instead # of finalizer error finally: try: if cls.admin_client is not None: if cls.is_client_async: - cls.adapt_call( - cls.admin_client.aclose()) + cls.adapt_call(cls.admin_client.aclose()) else: cls.admin_client.close() finally: @@ -597,6 +596,7 @@ def gen_lock_key(): return os.getpid() * 1000 + _lock_cnt -if os.environ.get('USE_UVLOOP'): +if os.environ.get("USE_UVLOOP"): import uvloop + uvloop.install() diff --git a/edgedb/_version.py b/edgedb/_version.py index e2475748..a71bcf1c 100644 --- a/edgedb/_version.py +++ b/edgedb/_version.py @@ -27,5 +27,6 @@ # The release automation will: build and test the packages for the # supported platforms, publish the packages on PyPI, merge the PR # to the target branch, create a Git tag pointing to the commit. +from __future__ import annotations -__version__ = '2.0.0a1' +__version__ = "2.0.0a1" diff --git a/edgedb/abstract.py b/edgedb/abstract.py index 158a269a..7efaa2d0 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -18,13 +18,12 @@ from __future__ import annotations + import abc import dataclasses import typing -from . import describe -from . import enums -from . import options +from . import describe, enums, options from .protocol import protocol __all__ = ( @@ -43,8 +42,8 @@ class QueryWithArgs(typing.NamedTuple): query: str - args: typing.Tuple - kwargs: typing.Dict[str, typing.Any] + args: tuple + kwargs: dict[str, typing.Any] class QueryCache(typing.NamedTuple): @@ -62,27 +61,27 @@ class QueryContext(typing.NamedTuple): query: QueryWithArgs cache: QueryCache query_options: QueryOptions - retry_options: typing.Optional[options.RetryOptions] - state: typing.Optional[options.State] + retry_options: options.RetryOptions | None + state: options.State | None class ExecuteContext(typing.NamedTuple): query: QueryWithArgs cache: QueryCache - state: typing.Optional[options.State] + state: options.State | None @dataclasses.dataclass class DescribeContext: query: str - state: typing.Optional[options.State] + state: options.State | None inject_type_names: bool @dataclasses.dataclass class DescribeResult: - input_type: typing.Optional[describe.AnyType] - output_type: typing.Optional[describe.AnyType] + input_type: describe.AnyType | None + output_type: describe.AnyType | None output_cardinality: enums.Cardinality capabilities: enums.Capability @@ -126,7 +125,7 @@ class BaseReadOnlyExecutor(abc.ABC): def _get_query_cache(self) -> QueryCache: ... - def _get_retry_options(self) -> typing.Optional[options.RetryOptions]: + def _get_retry_options(self) -> options.RetryOptions | None: return None def _get_state(self) -> options.State: @@ -134,7 +133,7 @@ def _get_state(self) -> options.State: class ReadOnlyExecutor(BaseReadOnlyExecutor): - """Subclasses can execute *at least* read-only queries""" + """Subclasses can execute *at least* read-only queries.""" __slots__ = () @@ -143,81 +142,93 @@ def _query(self, query_context: QueryContext): ... def query(self, query: str, *args, **kwargs) -> list: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) - - def query_single( - self, query: str, *args, **kwargs - ) -> typing.Union[typing.Any, None]: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) + + def query_single(self, query: str, *args, **kwargs) -> typing.Any | None: + return self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) def query_json(self, query: str, *args, **kwargs) -> str: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) def query_single_json(self, query: str, *args, **kwargs) -> str: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) def query_required_single_json(self, query: str, *args, **kwargs) -> str: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) @abc.abstractmethod def _execute(self, execute_context: ExecuteContext): ... def execute(self, commands: str, *args, **kwargs) -> None: - self._execute(ExecuteContext( - query=QueryWithArgs(commands, args, kwargs), - cache=self._get_query_cache(), - state=self._get_state(), - )) + self._execute( + ExecuteContext( + query=QueryWithArgs(commands, args, kwargs), + cache=self._get_query_cache(), + state=self._get_state(), + ) + ) class Executor(ReadOnlyExecutor): - """Subclasses can execute both read-only and modification queries""" + """Subclasses can execute both read-only and modification queries.""" __slots__ = () class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor): - """Subclasses can execute *at least* read-only queries""" + """Subclasses can execute *at least* read-only queries.""" __slots__ = () @@ -226,82 +237,90 @@ async def _query(self, query_context: QueryContext): ... async def query(self, query: str, *args, **kwargs) -> list: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return await self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) async def query_single(self, query: str, *args, **kwargs) -> typing.Any: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return await self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) async def query_required_single( - self, - query: str, - *args, - **kwargs + self, query: str, *args, **kwargs ) -> typing.Any: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return await self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) async def query_json(self, query: str, *args, **kwargs) -> str: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return await self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) async def query_single_json(self, query: str, *args, **kwargs) -> str: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return await self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) async def query_required_single_json( - self, - query: str, - *args, - **kwargs + self, query: str, *args, **kwargs ) -> str: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - )) + return await self._query( + QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + ) + ) @abc.abstractmethod async def _execute(self, execute_context: ExecuteContext) -> None: ... async def execute(self, commands: str, *args, **kwargs) -> None: - await self._execute(ExecuteContext( - query=QueryWithArgs(commands, args, kwargs), - cache=self._get_query_cache(), - state=self._get_state(), - )) + await self._execute( + ExecuteContext( + query=QueryWithArgs(commands, args, kwargs), + cache=self._get_query_cache(), + state=self._get_state(), + ) + ) class AsyncIOExecutor(AsyncIOReadOnlyExecutor): - """Subclasses can execute both read-only and modification queries""" + """Subclasses can execute both read-only and modification queries.""" __slots__ = () diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index 45d8487d..f0386256 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -15,26 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import asyncio import contextlib import logging import socket import ssl -import typing -from . import abstract -from . import base_client -from . import con_utils -from . import errors -from . import transaction +from . import abstract, base_client, con_utils, errors, transaction from .protocol import asyncio_proto - -__all__ = ( - 'create_async_client', 'AsyncIOClient' -) +__all__ = ("create_async_client", "AsyncIOClient") logger = logging.getLogger(__name__) @@ -95,7 +87,7 @@ async def _connect_addr(self, addr): raise con_utils.wrap_error(e) from e else: con_utils.check_alpn_protocol( - tr.get_extra_info('ssl_object') + tr.get_extra_info("ssl_object") ) except socket.gaierror as e: # All name resolution errors are considered temporary @@ -145,21 +137,22 @@ async def wait_until_released(self, timeout=None): class _AsyncIOPoolImpl(base_client.BasePoolImpl): - __slots__ = ('_loop',) + __slots__ = ("_loop",) _holder_class = _PoolConnectionHolder def __init__( self, connect_args, *, - max_concurrency: typing.Optional[int], + max_concurrency: int | None, connection_class, ): if not issubclass(connection_class, AsyncIOConnection): raise TypeError( - f'connection_class is expected to be a subclass of ' - f'edgedb.asyncio_client.AsyncIOConnection, ' - f'got {connection_class}') + f"connection_class is expected to be a subclass of " + f"edgedb.asyncio_client.AsyncIOConnection, " + f"got {connection_class}" + ) self._loop = None super().__init__( connect_args, @@ -199,20 +192,18 @@ async def _acquire_impl(): return proxy if self._closing: - raise errors.InterfaceError('pool is closing') + raise errors.InterfaceError("pool is closing") if timeout is None: return await _acquire_impl() else: - return await asyncio.wait_for( - _acquire_impl(), timeout=timeout) + return await asyncio.wait_for(_acquire_impl(), timeout=timeout) async def _release(self, holder): - if not isinstance(holder._con, AsyncIOConnection): raise errors.InterfaceError( - f'release() received invalid connection: ' - f'{holder._con!r} does not belong to any connection pool' + f"release() received invalid connection: " + f"{holder._con!r} does not belong to any connection pool" ) timeout = None @@ -244,14 +235,13 @@ async def aclose(self): try: warning_callback = self._loop.call_later( - 60, self._warn_on_long_close) + 60, self._warn_on_long_close + ) - release_coros = [ - ch.wait_until_released() for ch in self._holders] + release_coros = [ch.wait_until_released() for ch in self._holders] await asyncio.gather(*release_coros) - close_coros = [ - ch.close() for ch in self._holders] + close_coros = [ch.close() for ch in self._holders] await asyncio.gather(*close_coros) except (Exception, asyncio.CancelledError): @@ -265,14 +255,14 @@ async def aclose(self): def _warn_on_long_close(self): logger.warning( - 'AsyncIOClient.aclose() is taking over 60 seconds to complete. ' - 'Check if you have any unreleased connections left. ' - 'Use asyncio.wait_for() to set a timeout for ' - 'AsyncIOClient.aclose().') + "AsyncIOClient.aclose() is taking over 60 seconds to complete. " + "Check if you have any unreleased connections left. " + "Use asyncio.wait_for() to set a timeout for " + "AsyncIOClient.aclose()." + ) class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor): - __slots__ = ("_managed", "_locked") def __init__(self, retry, client, iteration): @@ -283,7 +273,8 @@ def __init__(self, retry, client, iteration): async def __aenter__(self): if self._managed: raise errors.InterfaceError( - 'cannot enter context: already in an `async with` block') + "cannot enter context: already in an `async with` block" + ) self._managed = True return self @@ -323,7 +314,6 @@ def _exclusive(self): class AsyncIORetry(transaction.BaseRetry): - def __aiter__(self): return self @@ -384,35 +374,36 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def _describe_query( self, query: str, *, inject_type_names: bool = False ) -> abstract.DescribeResult: - return await self._describe(abstract.DescribeContext( - query=query, - state=self._get_state(), - inject_type_names=inject_type_names, - )) + return await self._describe( + abstract.DescribeContext( + query=query, + state=self._get_state(), + inject_type_names=inject_type_names, + ) + ) def create_async_client( dsn=None, *, max_concurrency=None, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - secret_key: str = None, - database: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, + host: str | None = None, + port: int | None = None, + credentials: str | None = None, + credentials_file: str | None = None, + user: str | None = None, + password: str | None = None, + secret_key: str | None = None, + database: str | None = None, + tls_ca: str | None = None, + tls_ca_file: str | None = None, + tls_security: str | None = None, wait_until_available: int = 30, timeout: int = 10, ): return AsyncIOClient( connection_class=AsyncIOConnection, max_concurrency=max_concurrency, - # connect arguments dsn=dsn, host=host, diff --git a/edgedb/base_client.py b/edgedb/base_client.py index 331562e6..c7955b5a 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -15,31 +15,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import abc import random import time import typing -from . import abstract -from . import con_utils -from . import enums -from . import errors +from . import abstract, con_utils, enums, errors from . import options as _options from .protocol import protocol - -BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection') +BaseConnection_T = typing.TypeVar("BaseConnection_T", bound="BaseConnection") class BaseConnection(metaclass=abc.ABCMeta): _protocol: typing.Any - _addr: typing.Optional[typing.Union[str, typing.Tuple[str, int]]] - _addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]] + _addr: str | tuple[str, int] | None + _addrs: typing.Iterable[str | tuple[str, int]] _config: con_utils.ClientConfiguration _params: con_utils.ResolvedConnectConfig - _log_listeners: typing.Set[ + _log_listeners: set[ typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None] ] __slots__ = ( @@ -55,7 +51,7 @@ class BaseConnection(metaclass=abc.ABCMeta): def __init__( self, - addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]], + addrs: typing.Iterable[str | tuple[str, int]], config: con_utils.ClientConfiguration, params: con_utils.ResolvedConnectConfig, ): @@ -78,7 +74,7 @@ def _on_log_message(self, msg): def connected_addr(self): return self._addr - def _get_last_status(self) -> typing.Optional[str]: + def _get_last_status(self) -> str | None: if self._protocol is None: return None status = self._protocol.last_status @@ -94,8 +90,9 @@ def _cleanup(self): def add_log_listener( self: BaseConnection_T, - callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], - None] + callback: typing.Callable[ + [BaseConnection_T, errors.EdgeDBMessage], None + ], ) -> None: """Add a listener for EdgeDB log messages. @@ -108,8 +105,9 @@ def add_log_listener( def remove_log_listener( self: BaseConnection_T, - callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], - None] + callback: typing.Callable[ + [BaseConnection_T, errors.EdgeDBMessage], None + ], ) -> None: """Remove a listening callback for log messages.""" self._log_listeners.discard(callback) @@ -151,9 +149,8 @@ async def connect(self, *, single_attempt=False): f" {self._config.connect_timeout} sec" ) from e except errors.ClientConnectionError as e: - if ( - e.has_tag(errors.SHOULD_RECONNECT) and - (iteration == 1 or time.monotonic() < max_time) + if e.has_tag(errors.SHOULD_RECONNECT) and ( + iteration == 1 or time.monotonic() < max_time ): continue nice_err = e.__class__( @@ -162,7 +159,8 @@ async def connect(self, *, single_attempt=False): addr, attempts=iteration, duration=time.monotonic() - start, - )) + ) + ) raise nice_err from e.__cause__ else: return @@ -188,7 +186,8 @@ async def privileged_execute( allow_capabilities=enums.Capability.ALL, state=( execute_context.state.as_dict() - if execute_context.state else None + if execute_context.state + else None ), ) @@ -199,7 +198,7 @@ def is_in_transaction(self) -> bool: """ return self._protocol.is_in_transaction() - def get_settings(self) -> typing.Dict[str, typing.Any]: + def get_settings(self) -> dict[str, typing.Any]: return self._protocol.get_settings() async def raw_query(self, query_context: abstract.QueryContext): @@ -255,9 +254,8 @@ async def raw_query(self, query_context: abstract.QueryContext): # A query is read-only if it has no capabilities i.e. # capabilities == 0. Read-only queries are safe to retry. # Explicit transaction conflicts as well. - if ( - capabilities != 0 - and not isinstance(e, errors.TransactionConflictError) + if capabilities != 0 and not isinstance( + e, errors.TransactionConflictError ): raise e rule = query_context.retry_options.get_rule_for_exception(e) @@ -286,7 +284,8 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None: allow_capabilities=enums.Capability.EXECUTE, state=( execute_context.state.as_dict() - if execute_context.state else None + if execute_context.state + else None ), ) @@ -299,7 +298,8 @@ async def describe( inline_typenames=describe_context.inject_type_names, state=( describe_context.state.as_dict() - if describe_context.state else None + if describe_context.state + else None ), ) return abstract.DescribeResult( @@ -318,13 +318,15 @@ def terminate(self): def __repr__(self): if self.is_closed(): - return '<{classname} [closed] {id:#x}>'.format( - classname=self.__class__.__name__, id=id(self)) + return "<{classname} [closed] {id:#x}>".format( + classname=self.__class__.__name__, id=id(self) + ) else: - return '<{classname} [connected to {addr}] {id:#x}>'.format( + return "<{classname} [connected to {addr}] {id:#x}>".format( classname=self.__class__.__name__, addr=self.connected_addr(), - id=id(self)) + id=id(self), + ) class PoolConnectionHolder(abc.ABC): @@ -338,7 +340,6 @@ class PoolConnectionHolder(abc.ABC): _event_class = NotImplemented def __init__(self, pool): - self._pool = pool self._con = None @@ -359,8 +360,9 @@ async def wait_until_released(self, timeout=None): async def connect(self): if self._con is not None: raise errors.InternalClientError( - 'PoolConnectionHolder.connect() called while another ' - 'connection already exists') + "PoolConnectionHolder.connect() called while another " + "connection already exists" + ) self._con = await self._pool._get_new_connection() assert self._con._holder is None @@ -386,8 +388,9 @@ async def acquire(self) -> BaseConnection: async def release(self, timeout): if self._release_event.is_set(): raise errors.InternalClientError( - 'PoolConnectionHolder.release() called on ' - 'a free connection holder') + "PoolConnectionHolder.release() called on " + "a free connection holder" + ) if self._con.is_closed(): # This is usually the case when the connection is broken rather @@ -461,7 +464,7 @@ def __init__( connect_args, connection_factory, *, - max_concurrency: typing.Optional[int], + max_concurrency: int | None, ): self._connection_factory = connection_factory self._connect_args = connect_args @@ -470,7 +473,7 @@ def __init__( if max_concurrency is not None and max_concurrency <= 0: raise ValueError( - 'max_concurrency is expected to be greater than zero' + "max_concurrency is expected to be greater than zero" ) self._user_max_concurrency = max_concurrency @@ -519,7 +522,7 @@ def query_cache(self): def _resize_holder_pool(self): resize_diff = self._max_concurrency - len(self._holders) - if (resize_diff > 0): + if resize_diff > 0: if self._queue.maxsize != self._max_concurrency: self._set_queue_maxsize(self._max_concurrency) @@ -559,7 +562,6 @@ def set_connect_args(self, dsn=None, **connect_kwargs): Keyword arguments for the :func:`~edgedb.asyncio_client.create_async_client` function. """ - connect_kwargs["dsn"] = dsn self._connect_args = connect_kwargs self._codecs_registry = protocol.CodecsRegistry() @@ -586,7 +588,8 @@ async def _get_first_connection(self): if self._user_max_concurrency is None: suggested_concurrency = con.get_settings().get( - 'suggested_pool_concurrency') + "suggested_pool_concurrency" + ) if suggested_concurrency: self._max_concurrency = suggested_concurrency self._resize_holder_pool() @@ -610,11 +613,10 @@ async def _get_new_connection(self): return con async def release(self, connection): - if not isinstance(connection, BaseConnection): raise errors.InterfaceError( - f'BasePoolImpl.release() received invalid connection: ' - f'{connection!r} does not belong to any connection pool' + f"BasePoolImpl.release() received invalid connection: " + f"{connection!r} does not belong to any connection pool" ) ch = connection._holder @@ -624,8 +626,8 @@ async def release(self, connection): if ch._pool is not self: raise errors.InterfaceError( - f'BasePoolImpl.release() received invalid connection: ' - f'{connection!r} is not a member of this pool' + f"BasePoolImpl.release() received invalid connection: " + f"{connection!r} is not a member of this pool" ) return await self._release(ch) @@ -666,19 +668,19 @@ def __init__( self, *, connection_class, - max_concurrency: typing.Optional[int], + max_concurrency: int | None, dsn=None, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - secret_key: str = None, - database: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, + host: str | None = None, + port: int | None = None, + credentials: str | None = None, + credentials_file: str | None = None, + user: str | None = None, + password: str | None = None, + secret_key: str | None = None, + database: str | None = None, + tls_ca: str | None = None, + tls_ca_file: str | None = None, + tls_security: str | None = None, wait_until_available: int = 30, timeout: int = 10, **kwargs, @@ -719,7 +721,7 @@ def _get_query_cache(self) -> abstract.QueryCache: query_cache=self._impl.query_cache, ) - def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]: + def _get_retry_options(self) -> _options.RetryOptions | None: return self._options.retry_options def _get_state(self) -> _options.State: @@ -728,13 +730,11 @@ def _get_state(self) -> _options.State: @property def max_concurrency(self) -> int: """Max number of connections in the pool.""" - return self._impl.get_max_concurrency() @property def free_size(self) -> int: """Number of available connections in the pool.""" - return self._impl.get_free_size() async def _query(self, query_context: abstract.QueryContext): diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 7eb761b9..ba0fe2c9 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import contextlib import datetime @@ -24,16 +24,10 @@ import ssl import threading import time -import typing -from . import abstract -from . import base_client -from . import con_utils -from . import errors -from . import transaction +from . import abstract, base_client, con_utils, errors, transaction from .protocol import blocking_proto - DEFAULT_PING_BEFORE_IDLE_TIMEOUT = datetime.timedelta(seconds=5) MINIMUM_PING_WAIT_TIME = datetime.timedelta(seconds=1) @@ -119,8 +113,12 @@ async def sleep(self, seconds): def is_closed(self): proto = self._protocol - return not (proto and proto.sock is not None and - proto.sock.fileno() >= 0 and proto.connected) + return not ( + proto + and proto.sock is not None + and proto.sock.fileno() >= 0 + and proto.connected + ) async def close(self, timeout=None): """Send graceful termination message wait for connection to drop.""" @@ -176,14 +174,15 @@ def __init__( self, connect_args, *, - max_concurrency: typing.Optional[int], + max_concurrency: int | None, connection_class, ): if not issubclass(connection_class, BlockingIOConnection): raise TypeError( - f'connection_class is expected to be a subclass of ' - f'edgedb.blocking_client.BlockingIOConnection, ' - f'got {connection_class}') + f"connection_class is expected to be a subclass of " + f"edgedb.blocking_client.BlockingIOConnection, " + f"got {connection_class}" + ) super().__init__( connect_args, connection_class, @@ -209,7 +208,7 @@ async def acquire(self, timeout=None): self._ensure_initialized() if self._closing: - raise errors.InterfaceError('pool is closing') + raise errors.InterfaceError("pool is closing") ch = self._queue.get(timeout=timeout) try: @@ -226,8 +225,8 @@ async def acquire(self, timeout=None): async def _release(self, holder): if not isinstance(holder._con, BlockingIOConnection): raise errors.InterfaceError( - f'release() received invalid connection: ' - f'{holder._con!r} does not belong to any connection pool' + f"release() received invalid connection: " + f"{holder._con!r} does not belong to any connection pool" ) timeout = None @@ -271,7 +270,6 @@ async def close(self, timeout=None): class Iteration(transaction.BaseTransaction, abstract.Executor): - __slots__ = ("_managed", "_lock") def __init__(self, retry, client, iteration): @@ -283,7 +281,8 @@ def __enter__(self): with self._exclusive(): if self._managed: raise errors.InterfaceError( - 'cannot enter context: already in a `with` block') + "cannot enter context: already in a `with` block" + ) self._managed = True return self @@ -322,7 +321,6 @@ def _exclusive(self): class Retry(transaction.BaseRetry): - def __iter__(self): return self @@ -398,35 +396,38 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _describe_query( self, query: str, *, inject_type_names: bool = False ) -> abstract.DescribeResult: - return self._iter_coroutine(self._describe(abstract.DescribeContext( - query=query, - state=self._get_state(), - inject_type_names=inject_type_names, - ))) + return self._iter_coroutine( + self._describe( + abstract.DescribeContext( + query=query, + state=self._get_state(), + inject_type_names=inject_type_names, + ) + ) + ) def create_client( dsn=None, *, max_concurrency=None, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - secret_key: str = None, - database: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, + host: str | None = None, + port: int | None = None, + credentials: str | None = None, + credentials_file: str | None = None, + user: str | None = None, + password: str | None = None, + secret_key: str | None = None, + database: str | None = None, + tls_ca: str | None = None, + tls_ca_file: str | None = None, + tls_security: str | None = None, wait_until_available: int = 30, timeout: int = 10, ): return Client( connection_class=BlockingIOConnection, max_concurrency=max_concurrency, - # connect arguments dsn=dsn, host=host, diff --git a/edgedb/codegen/__main__.py b/edgedb/codegen/__main__.py index 1b07b01e..f8406414 100644 --- a/edgedb/codegen/__main__.py +++ b/edgedb/codegen/__main__.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations from .cli import main diff --git a/edgedb/codegen/cli.py b/edgedb/codegen/cli.py index e8149ae6..19ffbe0c 100644 --- a/edgedb/codegen/cli.py +++ b/edgedb/codegen/cli.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import argparse import sys @@ -66,7 +66,7 @@ def error(self, message): choices=["blocking", "async"], nargs="*", default=["async"], - help="Choose one or more targets to generate code (default is async)." + help="Choose one or more targets to generate code (default is async).", ) if sys.version_info[:2] >= (3, 9): parser.add_argument( @@ -88,7 +88,7 @@ def error(self, message): action="store_false", default=False, help="Add a mixin to generated dataclasses " - "to skip Pydantic validation (default is to add the mixin).", + "to skip Pydantic validation (default is to add the mixin).", ) diff --git a/edgedb/codegen/generator.py b/edgedb/codegen/generator.py index 626e735e..3be92719 100644 --- a/edgedb/codegen/generator.py +++ b/edgedb/codegen/generator.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations import argparse import getpass @@ -26,11 +27,9 @@ import typing import edgedb -from edgedb import abstract -from edgedb import describe -from edgedb.con_utils import find_edgedb_project_dir +from edgedb import abstract, describe from edgedb.color import get_color - +from edgedb.con_utils import find_edgedb_project_dir C = get_color() SYS_VERSION_INFO = os.getenv("EDGEDB_PYTHON_CODEGEN_PY_VER") @@ -262,9 +261,7 @@ def _generate_single_file(self, suffix: str): with target.open("w") as f: f.write(buf.getvalue()) - def _write_comments( - self, f: io.TextIOBase, src: typing.List[pathlib.Path] - ): + def _write_comments(self, f: io.TextIOBase, src: list[pathlib.Path]): src_str = map( lambda p: repr(p.relative_to(self._project_dir).as_posix()), src ) @@ -392,7 +389,7 @@ def _generate( return buf.getvalue() def _generate_code( - self, type_: typing.Optional[describe.AnyType], name_hint: str + self, type_: describe.AnyType | None, name_hint: str ) -> str: if type_ is None: return "None" @@ -463,9 +460,9 @@ def _generate_code( for el_name, el_code in link_props: print(f"{INDENT}@typing.overload", file=buf) print( - f'{INDENT}def __getitem__' + f"{INDENT}def __getitem__" f'(self, key: {typing_literal}["{el_name}"]) ' - f'-> {el_code}:', + f"-> {el_code}:", file=buf, ) print(f"{INDENT}{INDENT}...", file=buf) @@ -474,9 +471,7 @@ def _generate_code( f"{INDENT}def __getitem__(self, key: str) -> typing.Any:", file=buf, ) - print( - f"{INDENT}{INDENT}raise NotImplementedError", file=buf - ) + print(f"{INDENT}{INDENT}raise NotImplementedError", file=buf) self._defs[rv] = buf.getvalue().strip() @@ -513,7 +508,7 @@ def _generate_code( def _generate_code_with_cardinality( self, - type_: typing.Optional[describe.AnyType], + type_: describe.AnyType | None, name_hint: str, cardinality: edgedb.Cardinality, keyword_argument: bool = False, @@ -556,7 +551,7 @@ def _snake_to_camel(self, name: str) -> str: return name def _to_unique_idents( - self, names: typing.Iterable[typing.Tuple[str, str]] + self, names: typing.Iterable[tuple[str, str]] ) -> typing.Iterator[str]: dedup = set() for name in names: diff --git a/edgedb/color.py b/edgedb/color.py index 1e95c7bf..50bf1562 100644 --- a/edgedb/color.py +++ b/edgedb/color.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys import warnings @@ -30,15 +32,15 @@ def get_color() -> Color: except Exception: use_color = False if use_color: - COLOR.HEADER = '\033[95m' - COLOR.BLUE = '\033[94m' - COLOR.CYAN = '\033[96m' - COLOR.GREEN = '\033[92m' - COLOR.WARNING = '\033[93m' - COLOR.FAIL = '\033[91m' - COLOR.ENDC = '\033[0m' - COLOR.BOLD = '\033[1m' - COLOR.UNDERLINE = '\033[4m' + COLOR.HEADER = "\033[95m" + COLOR.BLUE = "\033[94m" + COLOR.CYAN = "\033[96m" + COLOR.GREEN = "\033[92m" + COLOR.WARNING = "\033[93m" + COLOR.FAIL = "\033[91m" + COLOR.ENDC = "\033[0m" + COLOR.BOLD = "\033[1m" + COLOR.UNDERLINE = "\033[4m" return COLOR @@ -49,9 +51,7 @@ def get_color() -> Color: "auto": lambda: sys.stderr.isatty(), "enabled": True, "disabled": False, - }[ - os.getenv("EDGEDB_COLOR_OUTPUT", "default") - ] + }[os.getenv("EDGEDB_COLOR_OUTPUT", "default")] except KeyError: warnings.warn( "EDGEDB_COLOR_OUTPUT can only be one of: " diff --git a/edgedb/con_utils.py b/edgedb/con_utils.py index d8e29881..cfa745c1 100644 --- a/edgedb/con_utils.py +++ b/edgedb/con_utils.py @@ -15,11 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import base64 import binascii import errno +import hashlib import json import os import re @@ -27,12 +28,9 @@ import typing import urllib.parse import warnings -import hashlib -from . import errors from . import credentials as cred_utils -from . import platform - +from . import errors, platform EDGEDB_PORT = 5656 ERRNO_RE = re.compile(r"\[Errno (\d+)\]") @@ -42,54 +40,55 @@ ConnectionResetError, FileNotFoundError, ) -TEMPORARY_ERROR_CODES = frozenset({ - errno.ECONNREFUSED, - errno.ECONNABORTED, - errno.ECONNRESET, - errno.ENOENT, -}) - -ISO_SECONDS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)S') -ISO_MINUTES_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M') -ISO_HOURS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)H') -ISO_UNITLESS_HOURS_RE = re.compile(r'^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$') -ISO_DAYS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)D') -ISO_WEEKS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)W') -ISO_MONTHS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M') -ISO_YEARS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)Y') +TEMPORARY_ERROR_CODES = frozenset( + { + errno.ECONNREFUSED, + errno.ECONNABORTED, + errno.ECONNRESET, + errno.ENOENT, + } +) + +ISO_SECONDS_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)S") +ISO_MINUTES_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M") +ISO_HOURS_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)H") +ISO_UNITLESS_HOURS_RE = re.compile(r"^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$") +ISO_DAYS_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)D") +ISO_WEEKS_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)W") +ISO_MONTHS_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M") +ISO_YEARS_RE = re.compile(r"(-?\d+|-?\d+\.\d*|-?\d*\.\d+)Y") HUMAN_HOURS_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:h(\s|\d|\.|$)|hours?(\s|$))', + r"((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:h(\s|\d|\.|$)|hours?(\s|$))", ) HUMAN_MINUTES_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:m(\s|\d|\.|$)|minutes?(\s|$))', + r"((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:m(\s|\d|\.|$)|minutes?(\s|$))", ) HUMAN_SECONDS_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:s(\s|\d|\.|$)|seconds?(\s|$))', + r"((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:s(\s|\d|\.|$)|seconds?(\s|$))", ) HUMAN_MS_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:ms(\s|\d|\.|$)|milliseconds?(\s|$))', + r"((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:ms(\s|\d|\.|$)|milliseconds?(\s|$))", ) HUMAN_US_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))', + r"((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))", ) INSTANCE_NAME_RE = re.compile( - r'^(\w(?:-?\w)*)$', + r"^(\w(?:-?\w)*)$", re.ASCII, ) CLOUD_INSTANCE_NAME_RE = re.compile( - r'^([A-Za-z0-9](?:-?[A-Za-z0-9])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$', + r"^([A-Za-z0-9](?:-?[A-Za-z0-9])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$", re.ASCII, ) DSN_RE = re.compile( - r'^[a-z]+://', + r"^[a-z]+://", re.IGNORECASE, ) DOMAIN_LABEL_MAX_LENGTH = 63 class ClientConfiguration(typing.NamedTuple): - connect_timeout: float command_timeout: float wait_until_available: float @@ -101,8 +100,10 @@ def _validate_port_spec(hosts, port): # match that of the host list. if len(port) != len(hosts): raise errors.InterfaceError( - 'could not match {} port numbers to {} hosts'.format( - len(port), len(hosts))) + "could not match {} port numbers to {} hosts".format( + len(port), len(hosts) + ) + ) else: port = [port for _ in range(len(hosts))] @@ -110,9 +111,9 @@ def _validate_port_spec(hosts, port): def _parse_hostlist(hostlist, port): - if ',' in hostlist: + if "," in hostlist: # A comma-separated list of host addresses. - hostspecs = hostlist.split(',') + hostspecs = hostlist.split(",") else: hostspecs = [hostlist] @@ -120,10 +121,10 @@ def _parse_hostlist(hostlist, port): hostlist_ports = [] if not port: - portspec = os.environ.get('EDGEDB_PORT') + portspec = os.environ.get("EDGEDB_PORT") if portspec: - if ',' in portspec: - default_port = [int(p) for p in portspec.split(',')] + if "," in portspec: + default_port = [int(p) for p in portspec.split(",")] else: default_port = int(portspec) else: @@ -135,7 +136,7 @@ def _parse_hostlist(hostlist, port): port = _validate_port_spec(hostspecs, port) for i, hostspec in enumerate(hostspecs): - addr, _, hostspec_port = hostspec.partition(':') + addr, _, hostspec_port = hostspec.partition(":") hosts.append(addr) if not port: @@ -152,15 +153,15 @@ def _parse_hostlist(hostlist, port): def _hash_path(path): path = os.path.realpath(path) - if platform.IS_WINDOWS and not path.startswith('\\\\'): - path = '\\\\?\\' + path - return hashlib.sha1(str(path).encode('utf-8')).hexdigest() + if platform.IS_WINDOWS and not path.startswith("\\\\"): + path = "\\\\?\\" + path + return hashlib.sha1(str(path).encode("utf-8")).hexdigest() def _stash_path(path): base_name = os.path.basename(path) - dir_name = base_name + '-' + _hash_path(path) - return platform.search_config_dir('projects', dir_name) + dir_name = base_name + "-" + _hash_path(path) + return platform.search_config_dir("projects", dir_name) def _validate_tls_security(val: str) -> str: @@ -207,51 +208,50 @@ class ResolvedConnectConfig: server_settings = {} def _set_param(self, param, value, source, validator=None): - param_name = '_' + param + param_name = "_" + param if getattr(self, param_name) is None: - setattr(self, param_name + '_source', source) + setattr(self, param_name + "_source", source) if value is not None: setattr( - self, - param_name, - validator(value) if validator else value + self, param_name, validator(value) if validator else value ) def set_host(self, host, source): - self._set_param('host', host, source, _validate_host) + self._set_param("host", host, source, _validate_host) def set_port(self, port, source): - self._set_param('port', port, source, _validate_port) + self._set_param("port", port, source, _validate_port) def set_database(self, database, source): - self._set_param('database', database, source, _validate_database) + self._set_param("database", database, source, _validate_database) def set_user(self, user, source): - self._set_param('user', user, source, _validate_user) + self._set_param("user", user, source, _validate_user) def set_password(self, password, source): - self._set_param('password', password, source) + self._set_param("password", password, source) def set_secret_key(self, secret_key, source): - self._set_param('secret_key', secret_key, source) + self._set_param("secret_key", secret_key, source) def set_tls_ca_data(self, ca_data, source): - self._set_param('tls_ca_data', ca_data, source) + self._set_param("tls_ca_data", ca_data, source) def set_tls_ca_file(self, ca_file, source): def read_ca_file(file_path): with open(file_path) as f: return f.read() - self._set_param('tls_ca_data', ca_file, source, read_ca_file) + self._set_param("tls_ca_data", ca_file, source, read_ca_file) def set_tls_security(self, security, source): - self._set_param('tls_security', security, source, - _validate_tls_security) + self._set_param( + "tls_security", security, source, _validate_tls_security + ) def set_wait_until_available(self, wait_until_available, source): self._set_param( - 'wait_until_available', + "wait_until_available", wait_until_available, source, _validate_wait_until_available, @@ -264,17 +264,17 @@ def add_server_settings(self, server_settings): @property def address(self): return ( - self._host if self._host else 'localhost', - self._port if self._port else 5656 + self._host if self._host else "localhost", + self._port if self._port else 5656, ) @property def database(self): - return self._database if self._database else 'edgedb' + return self._database if self._database else "edgedb" @property def user(self): - return self._user if self._user else 'edgedb' + return self._user if self._user else "edgedb" @property def password(self): @@ -286,29 +286,31 @@ def secret_key(self): @property def tls_security(self): - tls_security = self._tls_security or 'default' - security = os.environ.get('EDGEDB_CLIENT_SECURITY') or 'default' - if security not in {'default', 'insecure_dev_mode', 'strict'}: + tls_security = self._tls_security or "default" + security = os.environ.get("EDGEDB_CLIENT_SECURITY") or "default" + if security not in {"default", "insecure_dev_mode", "strict"}: raise ValueError( - f'environment variable EDGEDB_CLIENT_SECURITY should be ' - f'one of strict, insecure_dev_mode or default, ' - f'got: {security!r}') + f"environment variable EDGEDB_CLIENT_SECURITY should be " + f"one of strict, insecure_dev_mode or default, " + f"got: {security!r}" + ) - if security == 'default': + if security == "default": pass - elif security == 'insecure_dev_mode': - if tls_security == 'default': - tls_security = 'insecure' - elif security == 'strict': - if tls_security == 'default': - tls_security = 'strict' - elif tls_security in {'no_host_verification', 'insecure'}: + elif security == "insecure_dev_mode": + if tls_security == "default": + tls_security = "insecure" + elif security == "strict": + if tls_security == "default": + tls_security = "strict" + elif tls_security in {"no_host_verification", "insecure"}: raise ValueError( - f'EDGEDB_CLIENT_SECURITY=strict but ' - f'tls_security={tls_security}, tls_security must be ' - f'set to strict when EDGEDB_CLIENT_SECURITY is strict') + f"EDGEDB_CLIENT_SECURITY=strict but " + f"tls_security={tls_security}, tls_security must be " + f"set to strict when EDGEDB_CLIENT_SECURITY is strict" + ) - if tls_security != 'default': + if tls_security != "default": return tls_security if self._tls_ca_data is not None: @@ -320,19 +322,18 @@ def tls_security(self): @property def ssl_ctx(self): - if (self._ssl_ctx): + if self._ssl_ctx: return self._ssl_ctx self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) if self._tls_ca_data: - self._ssl_ctx.load_verify_locations( - cadata=self._tls_ca_data - ) + self._ssl_ctx.load_verify_locations(cadata=self._tls_ca_data) else: self._ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) if platform.IS_WINDOWS: import certifi + self._ssl_ctx.load_verify_locations(cafile=certifi.where()) tls_security = self.tls_security @@ -343,7 +344,7 @@ def ssl_ctx(self): else: self._ssl_ctx.verify_mode = ssl.CERT_NONE - self._ssl_ctx.set_alpn_protocols(['edgedb-binary']) + self._ssl_ctx.set_alpn_protocols(["edgedb-binary"]) return self._ssl_ctx @@ -357,18 +358,18 @@ def wait_until_available(self): def _validate_host(host): - if '/' in host: - raise ValueError('unix socket paths not supported') - if host == '' or ',' in host: + if "/" in host: + raise ValueError("unix socket paths not supported") + if host == "" or "," in host: raise ValueError(f'invalid host: "{host}"') return host def _prepare_host_for_dsn(host): host = _validate_host(host) - if ':' in host: + if ":" in host: # IPv6 - host = f'[{host}]' + host = f"[{host}]" return host @@ -379,25 +380,25 @@ def _validate_port(port): if not isinstance(port, int): raise ValueError() except Exception: - raise ValueError(f'invalid port: {port}, not an integer') + raise ValueError(f"invalid port: {port}, not an integer") if port < 1 or port > 65535: - raise ValueError(f'invalid port: {port}, must be between 1 and 65535') + raise ValueError(f"invalid port: {port}, must be between 1 and 65535") return port def _validate_database(database): - if database == '': - raise ValueError(f'invalid database name: {database}') + if database == "": + raise ValueError(f"invalid database name: {database}") return database def _validate_user(user): - if user == '': - raise ValueError(f'invalid user name: {user}') + if user == "": + raise ValueError(f"invalid user name: {user}") return user -def _pop_iso_unit(rgex: re.Pattern, string: str) -> typing.Tuple[float, str]: +def _pop_iso_unit(rgex: re.Pattern, string: str) -> tuple[float, str]: s = string total = 0 match = rgex.search(string) @@ -408,7 +409,7 @@ def _pop_iso_unit(rgex: re.Pattern, string: str) -> typing.Tuple[float, str]: return (total, s) -def _parse_iso_duration(string: str) -> typing.Union[float, int]: +def _parse_iso_duration(string: str) -> float | int: if not string.startswith("PT"): raise ValueError(f"invalid duration {string!r}") @@ -423,19 +424,19 @@ def _parse_iso_duration(string: str) -> typing.Union[float, int]: seconds, time = _pop_iso_unit(ISO_SECONDS_RE, time) if time: - raise ValueError(f'invalid duration {string!r}') + raise ValueError(f"invalid duration {string!r}") return 3600 * hours + 60 * minutes + seconds def _remove_white_space(s: str) -> str: - return ''.join(c for c in s if not c.isspace()) + return "".join(c for c in s if not c.isspace()) def _pop_human_duration_unit( rgex: re.Pattern, string: str, -) -> typing.Tuple[float, bool, str]: +) -> tuple[float, bool, str]: match = rgex.search(string) if not match: return 0, False, string @@ -443,10 +444,10 @@ def _pop_human_duration_unit( number = 0 if match.group(1): literal = _remove_white_space(match.group(1)) - if literal.endswith('.'): + if literal.endswith("."): return 0, False, string - if literal.startswith('-.'): + if literal.startswith("-."): return 0, False, string number = float(literal) @@ -478,13 +479,13 @@ def _parse_human_duration(string: str) -> float: found |= f if s.strip() or not found: - raise ValueError(f'invalid duration {string!r}') + raise ValueError(f"invalid duration {string!r}") return 3600 * hour + 60 * minute + second + 0.001 * ms + 0.000001 * us def _parse_duration_str(string: str) -> float: - if string.startswith('PT'): + if string.startswith("PT"): return _parse_iso_duration(string) return _parse_human_duration(string) @@ -501,13 +502,13 @@ def _validate_wait_until_available(wait_until_available): def _validate_server_settings(server_settings): if ( - not isinstance(server_settings, dict) or - not all(isinstance(k, str) for k in server_settings) or - not all(isinstance(v, str) for v in server_settings.values()) + not isinstance(server_settings, dict) + or not all(isinstance(k, str) for k in server_settings) + or not all(isinstance(v, str) for v in server_settings.values()) ): raise ValueError( - 'server_settings is expected to be None or ' - 'a Dict[str, str]') + "server_settings is expected to be None or " "a Dict[str, str]" + ) def _parse_connect_dsn_and_args( @@ -536,148 +537,172 @@ def _parse_connect_dsn_and_args( has_compound_options = _resolve_config_options( resolved_config, - 'Cannot have more than one of the following connection options: ' + "Cannot have more than one of the following connection options: " + '"dsn", "credentials", "credentials_file" or "host"/"port"', dsn=(dsn, '"dsn" option') if dsn is not None else None, instance_name=( (instance_name, '"dsn" option (parsed as instance name)') - if instance_name is not None else None + if instance_name is not None + else None ), credentials=( (credentials, '"credentials" option') - if credentials is not None else None + if credentials is not None + else None ), credentials_file=( (credentials_file, '"credentials_file" option') - if credentials_file is not None else None + if credentials_file is not None + else None ), host=(host, '"host" option') if host is not None else None, port=(port, '"port" option') if port is not None else None, database=( - (database, '"database" option') - if database is not None else None + (database, '"database" option') if database is not None else None ), user=(user, '"user" option') if user is not None else None, password=( - (password, '"password" option') - if password is not None else None + (password, '"password" option') if password is not None else None ), secret_key=( (secret_key, '"secret_key" option') - if secret_key is not None else None - ), - tls_ca=( - (tls_ca, '"tls_ca" option') - if tls_ca is not None else None + if secret_key is not None + else None ), + tls_ca=((tls_ca, '"tls_ca" option') if tls_ca is not None else None), tls_ca_file=( (tls_ca_file, '"tls_ca_file" option') - if tls_ca_file is not None else None + if tls_ca_file is not None + else None ), tls_security=( (tls_security, '"tls_security" option') - if tls_security is not None else None + if tls_security is not None + else None ), server_settings=( (server_settings, '"server_settings" option') - if server_settings is not None else None + if server_settings is not None + else None ), wait_until_available=( (wait_until_available, '"wait_until_available" option') - if wait_until_available is not None else None + if wait_until_available is not None + else None ), ) if has_compound_options is False: env_port = os.getenv("EDGEDB_PORT") if ( - resolved_config._port is None and - env_port and env_port.startswith('tcp://') + resolved_config._port is None + and env_port + and env_port.startswith("tcp://") ): # EDGEDB_PORT is set by 'docker --link' so ignore and warn - warnings.warn('EDGEDB_PORT in "tcp://host:port" format, ' + - 'so will be ignored') + warnings.warn( + 'EDGEDB_PORT in "tcp://host:port" format, ' + + "so will be ignored" + ) env_port = None - env_dsn = os.getenv('EDGEDB_DSN') - env_instance = os.getenv('EDGEDB_INSTANCE') - env_credentials_file = os.getenv('EDGEDB_CREDENTIALS_FILE') - env_host = os.getenv('EDGEDB_HOST') - env_database = os.getenv('EDGEDB_DATABASE') - env_user = os.getenv('EDGEDB_USER') - env_password = os.getenv('EDGEDB_PASSWORD') - env_secret_key = os.getenv('EDGEDB_SECRET_KEY') - env_tls_ca = os.getenv('EDGEDB_TLS_CA') - env_tls_ca_file = os.getenv('EDGEDB_TLS_CA_FILE') - env_tls_security = os.getenv('EDGEDB_CLIENT_TLS_SECURITY') - env_wait_until_available = os.getenv('EDGEDB_WAIT_UNTIL_AVAILABLE') - cloud_profile = os.getenv('EDGEDB_CLOUD_PROFILE') + env_dsn = os.getenv("EDGEDB_DSN") + env_instance = os.getenv("EDGEDB_INSTANCE") + env_credentials_file = os.getenv("EDGEDB_CREDENTIALS_FILE") + env_host = os.getenv("EDGEDB_HOST") + env_database = os.getenv("EDGEDB_DATABASE") + env_user = os.getenv("EDGEDB_USER") + env_password = os.getenv("EDGEDB_PASSWORD") + env_secret_key = os.getenv("EDGEDB_SECRET_KEY") + env_tls_ca = os.getenv("EDGEDB_TLS_CA") + env_tls_ca_file = os.getenv("EDGEDB_TLS_CA_FILE") + env_tls_security = os.getenv("EDGEDB_CLIENT_TLS_SECURITY") + env_wait_until_available = os.getenv("EDGEDB_WAIT_UNTIL_AVAILABLE") + cloud_profile = os.getenv("EDGEDB_CLOUD_PROFILE") has_compound_options = _resolve_config_options( resolved_config, - 'Cannot have more than one of the following connection ' + "Cannot have more than one of the following connection " + 'environment variables: "EDGEDB_DSN", "EDGEDB_INSTANCE", ' + '"EDGEDB_CREDENTIALS_FILE" or "EDGEDB_HOST"/"EDGEDB_PORT"', dsn=( (env_dsn, '"EDGEDB_DSN" environment variable') - if env_dsn is not None else None + if env_dsn is not None + else None ), instance_name=( (env_instance, '"EDGEDB_INSTANCE" environment variable') - if env_instance is not None else None + if env_instance is not None + else None ), credentials_file=( - (env_credentials_file, - '"EDGEDB_CREDENTIALS_FILE" environment variable') - if env_credentials_file is not None else None + ( + env_credentials_file, + '"EDGEDB_CREDENTIALS_FILE" environment variable', + ) + if env_credentials_file is not None + else None ), host=( (env_host, '"EDGEDB_HOST" environment variable') - if env_host is not None else None + if env_host is not None + else None ), port=( (env_port, '"EDGEDB_PORT" environment variable') - if env_port is not None else None + if env_port is not None + else None ), database=( (env_database, '"EDGEDB_DATABASE" environment variable') - if env_database is not None else None + if env_database is not None + else None ), user=( (env_user, '"EDGEDB_USER" environment variable') - if env_user is not None else None + if env_user is not None + else None ), password=( (env_password, '"EDGEDB_PASSWORD" environment variable') - if env_password is not None else None + if env_password is not None + else None ), secret_key=( (env_secret_key, '"EDGEDB_SECRET_KEY" environment variable') - if env_secret_key is not None else None + if env_secret_key is not None + else None ), tls_ca=( (env_tls_ca, '"EDGEDB_TLS_CA" environment variable') - if env_tls_ca is not None else None + if env_tls_ca is not None + else None ), tls_ca_file=( (env_tls_ca_file, '"EDGEDB_TLS_CA_FILE" environment variable') - if env_tls_ca_file is not None else None + if env_tls_ca_file is not None + else None ), tls_security=( - (env_tls_security, - '"EDGEDB_CLIENT_TLS_SECURITY" environment variable') - if env_tls_security is not None else None + ( + env_tls_security, + '"EDGEDB_CLIENT_TLS_SECURITY" environment variable', + ) + if env_tls_security is not None + else None ), wait_until_available=( ( env_wait_until_available, - '"EDGEDB_WAIT_UNTIL_AVAILABLE" environment variable' - ) if env_wait_until_available is not None else None + '"EDGEDB_WAIT_UNTIL_AVAILABLE" environment variable', + ) + if env_wait_until_available is not None + else None ), cloud_profile=( - (cloud_profile, - '"EDGEDB_CLOUD_PROFILE" environment variable') - if cloud_profile is not None else None + (cloud_profile, '"EDGEDB_CLOUD_PROFILE" environment variable') + if cloud_profile is not None + else None ), ) @@ -685,39 +710,38 @@ def _parse_connect_dsn_and_args( dir = find_edgedb_project_dir() stash_dir = _stash_path(dir) if os.path.exists(stash_dir): - with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f: + with open(os.path.join(stash_dir, "instance-name")) as f: instance_name = f.read().strip() - cloud_profile_file = os.path.join(stash_dir, 'cloud-profile') + cloud_profile_file = os.path.join(stash_dir, "cloud-profile") if os.path.exists(cloud_profile_file): - with open(cloud_profile_file, 'rt') as f: + with open(cloud_profile_file) as f: cloud_profile = f.read().strip() else: cloud_profile = None _resolve_config_options( resolved_config, - '', + "", instance_name=( instance_name, - f'project linked instance ("{instance_name}")' + f'project linked instance ("{instance_name}")', ), cloud_profile=( cloud_profile, - f'project defined cloud profile ("{cloud_profile}")' + f'project defined cloud profile ("{cloud_profile}")', ), ) else: raise errors.ClientConnectionError( - f'Found `edgedb.toml` but the project is not initialized. ' - f'Run `edgedb project init`.' + "Found `edgedb.toml` but the project is not initialized. " + "Run `edgedb project init`." ) return resolved_config def _parse_dsn_into_config( - resolved_config: ResolvedConnectConfig, - dsn: typing.Tuple[str, str] + resolved_config: ResolvedConnectConfig, dsn: tuple[str, str] ): dsn_str, source = dsn @@ -731,66 +755,69 @@ def _parse_dsn_into_config( user = parsed.username password = parsed.password except Exception as e: - raise ValueError(f'invalid DSN or instance name: {str(e)}') + raise ValueError(f"invalid DSN or instance name: {e!s}") - if parsed.scheme != 'edgedb': + if parsed.scheme != "edgedb": raise ValueError( - f'invalid DSN: scheme is expected to be ' - f'"edgedb", got {parsed.scheme!r}') + f"invalid DSN: scheme is expected to be " + f'"edgedb", got {parsed.scheme!r}' + ) query = ( urllib.parse.parse_qs(parsed.query, keep_blank_values=True) - if parsed.query != '' + if parsed.query != "" else {} ) for key, val in query.items(): if isinstance(val, list): if len(val) > 1: raise ValueError( - f'invalid DSN: duplicate query parameter {key}') + f"invalid DSN: duplicate query parameter {key}" + ) query[key] = val[-1] def handle_dsn_part( - paramName, value, currentValue, setter, - formatter=lambda val: val + paramName, value, currentValue, setter, formatter=lambda val: val ): param_values = [ - (value if value != '' else None), + (value if value != "" else None), query.get(paramName), - query.get(paramName + '_env'), - query.get(paramName + '_file') + query.get(paramName + "_env"), + query.get(paramName + "_file"), ] if len([p for p in param_values if p is not None]) > 1: raise ValueError( - f'invalid DSN: more than one of ' + - f'{(paramName + ", ") if value else ""}' + - f'?{paramName}=, ?{paramName}_env=, ?{paramName}_file= ' + - f'was specified' + "invalid DSN: more than one of " + + f'{(paramName + ", ") if value else ""}' + + f"?{paramName}=, ?{paramName}_env=, ?{paramName}_file= " + + "was specified" ) if currentValue is None: param = ( - value if (value is not None and value != '') + value + if (value is not None and value != "") else query.get(paramName) ) paramSource = source if param is None: - env = query.get(paramName + '_env') + env = query.get(paramName + "_env") if env is not None: param = os.getenv(env) if param is None: raise ValueError( - f'{paramName}_env environment variable "{env}" ' + - f'doesn\'t exist') - paramSource = paramSource + f' ({paramName}_env: {env})' + f'{paramName}_env environment variable "{env}" ' + + "doesn't exist" + ) + paramSource = paramSource + f" ({paramName}_env: {env})" if param is None: - filename = query.get(paramName + '_file') + filename = query.get(paramName + "_file") if filename is not None: with open(filename) as f: param = f.read() paramSource = ( - paramSource + f' ({paramName}_file: {filename})' + paramSource + f" ({paramName}_file: {filename})" ) param = formatter(param) if param is not None else None @@ -798,55 +825,65 @@ def handle_dsn_part( setter(param, paramSource) query.pop(paramName, None) - query.pop(paramName + '_env', None) - query.pop(paramName + '_file', None) + query.pop(paramName + "_env", None) + query.pop(paramName + "_file", None) handle_dsn_part( - 'host', host, resolved_config._host, resolved_config.set_host + "host", host, resolved_config._host, resolved_config.set_host ) handle_dsn_part( - 'port', port, resolved_config._port, resolved_config.set_port + "port", port, resolved_config._port, resolved_config.set_port ) def strip_leading_slash(str): - return str[1:] if str.startswith('/') else str + return str[1:] if str.startswith("/") else str handle_dsn_part( - 'database', strip_leading_slash(database), - resolved_config._database, resolved_config.set_database, - strip_leading_slash + "database", + strip_leading_slash(database), + resolved_config._database, + resolved_config.set_database, + strip_leading_slash, ) handle_dsn_part( - 'user', user, resolved_config._user, resolved_config.set_user + "user", user, resolved_config._user, resolved_config.set_user ) handle_dsn_part( - 'password', password, - resolved_config._password, resolved_config.set_password + "password", + password, + resolved_config._password, + resolved_config.set_password, ) handle_dsn_part( - 'secret_key', None, - resolved_config._secret_key, resolved_config.set_secret_key + "secret_key", + None, + resolved_config._secret_key, + resolved_config.set_secret_key, ) handle_dsn_part( - 'tls_ca_file', None, - resolved_config._tls_ca_data, resolved_config.set_tls_ca_file + "tls_ca_file", + None, + resolved_config._tls_ca_data, + resolved_config.set_tls_ca_file, ) handle_dsn_part( - 'tls_security', None, + "tls_security", + None, resolved_config._tls_security, - resolved_config.set_tls_security + resolved_config.set_tls_security, ) handle_dsn_part( - 'wait_until_available', None, + "wait_until_available", + None, resolved_config._wait_until_available, - resolved_config.set_wait_until_available + resolved_config.set_wait_until_available, ) resolved_config.add_server_settings(query) @@ -855,9 +892,9 @@ def strip_leading_slash(str): def _jwt_base64_decode(payload): remainder = len(payload) % 4 if remainder == 2: - payload += '==' + payload += "==" elif remainder == 3: - payload += '=' + payload += "=" elif remainder != 0: raise errors.ClientConnectionError("Invalid secret key") payload = base64.urlsafe_b64decode(payload.encode("utf-8")) @@ -887,7 +924,7 @@ def _parse_cloud_instance_name_into_config( profile = resolved_config._cloud_profile profile_src = resolved_config._cloud_profile_source path = config_dir / "cloud-credentials" / f"{profile}.json" - with open(path, "rt") as f: + with open(path) as f: secret_key = json.load(f)["secret_key"] except Exception: raise errors.ClientConnectionError( @@ -903,7 +940,7 @@ def _parse_cloud_instance_name_into_config( raise except Exception: raise errors.ClientConnectionError("Invalid secret key") - payload = f"{org_slug}/{instance_name}".encode("utf-8") + payload = f"{org_slug}/{instance_name}".encode() dns_bucket = binascii.crc_hqx(payload, 0) % 100 host = f"{label}.c-{dns_bucket:02d}.i.{dns_zone}" resolved_config.set_host(host, source) @@ -941,7 +978,8 @@ def _resolve_config_options( if tls_ca_file is not None: if tls_ca is not None: raise errors.ClientConnectionError( - f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive") + f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive" + ) resolved_config.set_tls_ca_file(*tls_ca_file) if tls_ca is not None: resolved_config.set_tls_ca_data(*tls_ca) @@ -952,7 +990,7 @@ def _resolve_config_options( if wait_until_available is not None: resolved_config.set_wait_until_available(*wait_until_available) if cloud_profile is not None: - resolved_config._set_param('cloud_profile', *cloud_profile) + resolved_config._set_param("cloud_profile", *cloud_profile) compound_params = [ dsn, @@ -972,9 +1010,9 @@ def _resolve_config_options( resolved_config.set_port(*port) if dsn is None: dsn = ( - 'edgedb://' + - (_prepare_host_for_dsn(host[0]) if host else ''), - host[1] if host is not None else port[1] + "edgedb://" + + (_prepare_host_for_dsn(host[0]) if host else ""), + host[1] if host is not None else port[1], ) _parse_dsn_into_config(resolved_config, dsn) else: @@ -985,7 +1023,7 @@ def _resolve_config_options( try: cred_data = json.loads(credentials[0]) except ValueError as e: - raise RuntimeError(f"cannot read credentials") from e + raise RuntimeError("cannot read credentials") from e else: creds = cred_utils.validate_credentials(cred_data) source = "credentials" @@ -1007,16 +1045,13 @@ def _resolve_config_options( ) return True - resolved_config.set_host(creds.get('host'), source) - resolved_config.set_port(creds.get('port'), source) - resolved_config.set_database(creds.get('database'), source) - resolved_config.set_user(creds.get('user'), source) - resolved_config.set_password(creds.get('password'), source) - resolved_config.set_tls_ca_data(creds.get('tls_ca'), source) - resolved_config.set_tls_security( - creds.get('tls_security'), - source - ) + resolved_config.set_host(creds.get("host"), source) + resolved_config.set_port(creds.get("port"), source) + resolved_config.set_database(creds.get("database"), source) + resolved_config.set_user(creds.get("user"), source) + resolved_config.set_password(creds.get("password"), source) + resolved_config.set_tls_ca_data(creds.get("tls_ca"), source) + resolved_config.set_tls_security(creds.get("tls_security"), source) return True @@ -1029,21 +1064,21 @@ def find_edgedb_project_dir(): dev = os.stat(dir).st_dev while True: - toml = os.path.join(dir, 'edgedb.toml') + toml = os.path.join(dir, "edgedb.toml") if not os.path.isfile(toml): parent = os.path.dirname(dir) if parent == dir: raise errors.ClientConnectionError( - f'no `edgedb.toml` found and ' - f'no connection options specified' + "no `edgedb.toml` found and " + "no connection options specified" ) parent_dev = os.stat(parent).st_dev if parent_dev != dev: raise errors.ClientConnectionError( - f'no `edgedb.toml` found and ' - f'no connection options specified' - f'(stopped searching for `edgedb.toml` at file system' - f'boundary {dir!r})' + f"no `edgedb.toml` found and " + f"no connection options specified" + f"(stopped searching for `edgedb.toml` at file system" + f"boundary {dir!r})" ) dir = parent dev = parent_dev @@ -1069,8 +1104,7 @@ def parse_connect_arguments( command_timeout, wait_until_available, server_settings, -) -> typing.Tuple[ResolvedConnectConfig, ClientConfiguration]: - +) -> tuple[ResolvedConnectConfig, ClientConfiguration]: if command_timeout is not None: try: if isinstance(command_timeout, bool): @@ -1080,9 +1114,11 @@ def parse_connect_arguments( raise ValueError except ValueError: raise ValueError( - 'invalid command_timeout value: ' - 'expected greater than 0 float (got {!r})'.format( - command_timeout)) from None + "invalid command_timeout value: " + "expected greater than 0 float (got {!r})".format( + command_timeout + ) + ) from None connect_config = _parse_connect_dsn_and_args( dsn=dsn, @@ -1111,7 +1147,7 @@ def parse_connect_arguments( def check_alpn_protocol(ssl_obj): - if ssl_obj.selected_alpn_protocol() != 'edgedb-binary': + if ssl_obj.selected_alpn_protocol() != "edgedb-binary": raise errors.ClientConnectionFailedError( "The server doesn't support the edgedb-binary protocol." ) @@ -1120,24 +1156,24 @@ def check_alpn_protocol(ssl_obj): def render_client_no_connection_error(prefix, addr, attempts, duration): if isinstance(addr, str): msg = ( - f'{prefix}' - f'\n\tAfter {attempts} attempts in {duration:.1f} sec' - f'\n\tIs the server running locally and accepting ' - f'\n\tconnections on Unix domain socket {addr!r}?' + f"{prefix}" + f"\n\tAfter {attempts} attempts in {duration:.1f} sec" + f"\n\tIs the server running locally and accepting " + f"\n\tconnections on Unix domain socket {addr!r}?" ) else: msg = ( - f'{prefix}' - f'\n\tAfter {attempts} attempts in {duration:.1f} sec' - f'\n\tIs the server running on host {addr[0]!r} ' - f'and accepting ' - f'\n\tTCP/IP connections on port {addr[1]}?' + f"{prefix}" + f"\n\tAfter {attempts} attempts in {duration:.1f} sec" + f"\n\tIs the server running on host {addr[0]!r} " + f"and accepting " + f"\n\tTCP/IP connections on port {addr[1]}?" ) return msg def _extract_errno(s): - """Extract multiple errnos from error string + """Extract multiple errnos from error string. When we connect to a host that has multiple underlying IP addresses, say ``localhost`` having ``::1`` and ``127.0.0.1``, we get @@ -1159,7 +1195,7 @@ def wrap_error(e): errnos = [e.errno] if errnos: - is_temp = any((code in TEMPORARY_ERROR_CODES for code in errnos)) + is_temp = any(code in TEMPORARY_ERROR_CODES for code in errnos) else: is_temp = isinstance(e, TEMPORARY_ERRORS) diff --git a/edgedb/connresource.py b/edgedb/connresource.py index e999a0fe..6da3ec7c 100644 --- a/edgedb/connresource.py +++ b/edgedb/connresource.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import functools @@ -34,7 +34,6 @@ def _check(self, *args, **kwargs): class ConnectionResource: - def __init__(self, connection): self._connection = connection self._con_release_ctr = connection._pool_release_ctr @@ -43,12 +42,15 @@ def _check_conn_validity(self, meth_name): con_release_ctr = self._connection._pool_release_ctr if con_release_ctr != self._con_release_ctr: raise errors.InterfaceError( - 'cannot call {}.{}(): ' - 'the underlying connection has been released back ' - 'to the pool'.format(self.__class__.__name__, meth_name)) + "cannot call {}.{}(): " + "the underlying connection has been released back " + "to the pool".format(self.__class__.__name__, meth_name) + ) if self._connection.is_closed(): raise errors.InterfaceError( - 'cannot call {}.{}(): ' - 'the underlying connection is closed'.format( - self.__class__.__name__, meth_name)) + "cannot call {}.{}(): " + "the underlying connection is closed".format( + self.__class__.__name__, meth_name + ) + ) diff --git a/edgedb/credentials.py b/edgedb/credentials.py index aa6266bc..d6ebe40c 100644 --- a/edgedb/credentials.py +++ b/edgedb/credentials.py @@ -1,7 +1,9 @@ +from __future__ import annotations + +import json import os import pathlib import typing -import json from . import platform @@ -12,11 +14,11 @@ class RequiredCredentials(typing.TypedDict, total=True): class Credentials(RequiredCredentials, total=False): - host: typing.Optional[str] - password: typing.Optional[str] - database: typing.Optional[str] - tls_ca: typing.Optional[str] - tls_security: typing.Optional[str] + host: str | None + password: str | None + database: str | None + tls_ca: str | None + tls_security: str | None def get_credentials_path(instance_name: str) -> pathlib.Path: @@ -25,23 +27,21 @@ def get_credentials_path(instance_name: str) -> pathlib.Path: def read_credentials(path: os.PathLike) -> Credentials: try: - with open(path, encoding='utf-8') as f: + with open(path, encoding="utf-8") as f: credentials = json.load(f) return validate_credentials(credentials) except Exception as e: - raise RuntimeError( - f"cannot read credentials at {path}" - ) from e + raise RuntimeError(f"cannot read credentials at {path}") from e def validate_credentials(data: dict) -> Credentials: - port = data.get('port') + port = data.get("port") if port is None: port = 5656 if not isinstance(port, int) or port < 1 or port > 65535: raise ValueError("invalid `port` value") - user = data.get('user') + user = data.get("user") if user is None: raise ValueError("`user` key is required") if not isinstance(user, str): @@ -52,53 +52,56 @@ def validate_credentials(data: dict) -> Credentials: "port": port, } - host = data.get('host') + host = data.get("host") if host is not None: if not isinstance(host, str): raise ValueError("`host` must be a string") - result['host'] = host + result["host"] = host - database = data.get('database') + database = data.get("database") if database is not None: if not isinstance(database, str): raise ValueError("`database` must be a string") - result['database'] = database + result["database"] = database - password = data.get('password') + password = data.get("password") if password is not None: if not isinstance(password, str): raise ValueError("`password` must be a string") - result['password'] = password + result["password"] = password - ca = data.get('tls_ca') + ca = data.get("tls_ca") if ca is not None: if not isinstance(ca, str): raise ValueError("`tls_ca` must be a string") - result['tls_ca'] = ca + result["tls_ca"] = ca - cert_data = data.get('tls_cert_data') + cert_data = data.get("tls_cert_data") if cert_data is not None: if not isinstance(cert_data, str): raise ValueError("`tls_cert_data` must be a string") if ca is not None and ca != cert_data: raise ValueError( - f"tls_ca and tls_cert_data are both set and disagree") - result['tls_ca'] = cert_data + "tls_ca and tls_cert_data are both set and disagree" + ) + result["tls_ca"] = cert_data - verify = data.get('tls_verify_hostname') + verify = data.get("tls_verify_hostname") if verify is not None: if not isinstance(verify, bool): raise ValueError("`tls_verify_hostname` must be a bool") - result['tls_security'] = 'strict' if verify else 'no_host_verification' + result["tls_security"] = "strict" if verify else "no_host_verification" - tls_security = data.get('tls_security') + tls_security = data.get("tls_security") if tls_security is not None: if not isinstance(tls_security, str): raise ValueError("`tls_security` must be a string") - result['tls_security'] = tls_security + result["tls_security"] = tls_security - missmatch = ValueError(f"tls_verify_hostname={verify} and " - f"tls_security={tls_security} are incompatible") + missmatch = ValueError( + f"tls_verify_hostname={verify} and " + f"tls_security={tls_security} are incompatible" + ) if tls_security == "strict" and verify is False: raise missmatch if tls_security in {"no_host_verification", "insecure"} and verify is True: diff --git a/edgedb/datatypes/range.py b/edgedb/datatypes/range.py index eaeb4bcb..b0ed3e81 100644 --- a/edgedb/datatypes/range.py +++ b/edgedb/datatypes/range.py @@ -15,21 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from typing import Generic, Optional, TypeVar - +from typing import Generic, TypeVar T = TypeVar("T") class Range(Generic[T]): - __slots__ = ("_lower", "_upper", "_inc_lower", "_inc_upper", "_empty") def __init__( self, - lower: Optional[T] = None, - upper: Optional[T] = None, + lower: T | None = None, + upper: T | None = None, *, inc_lower: bool = True, inc_upper: bool = False, @@ -38,13 +37,10 @@ def __init__( self._empty = empty if empty: - if ( - lower != upper - or lower is not None and inc_upper and inc_lower - ): + if lower != upper or lower is not None and inc_upper and inc_lower: raise ValueError( "conflicting arguments in range constructor: " - "\"empty\" is `true` while the specified bounds " + '"empty" is `true` while the specified bounds ' "suggest otherwise" ) @@ -57,7 +53,7 @@ def __init__( self._inc_upper = upper is not None and inc_upper @property - def lower(self) -> Optional[T]: + def lower(self) -> T | None: return self._lower @property @@ -65,7 +61,7 @@ def inc_lower(self) -> bool: return self._inc_lower @property - def upper(self) -> Optional[T]: + def upper(self) -> T | None: return self._upper @property @@ -87,7 +83,7 @@ def __eq__(self, other): self._upper, self._inc_lower, self._inc_upper, - self._empty + self._empty, ) == ( other._lower, other._upper, @@ -97,13 +93,15 @@ def __eq__(self, other): ) def __hash__(self) -> int: - return hash(( - self._lower, - self._upper, - self._inc_lower, - self._inc_upper, - self._empty, - )) + return hash( + ( + self._lower, + self._upper, + self._inc_lower, + self._inc_upper, + self._empty, + ) + ) def __str__(self) -> str: if self._empty: diff --git a/edgedb/describe.py b/edgedb/describe.py index 05c16398..b72283b8 100644 --- a/edgedb/describe.py +++ b/edgedb/describe.py @@ -15,10 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import dataclasses -import typing import uuid from . import enums @@ -27,7 +26,7 @@ @dataclasses.dataclass(frozen=True) class AnyType: desc_id: uuid.UUID - name: typing.Optional[str] + name: str | None @dataclasses.dataclass(frozen=True) @@ -50,7 +49,7 @@ class SetType(SequenceType): @dataclasses.dataclass(frozen=True) class ObjectType(AnyType): - elements: typing.Dict[str, Element] + elements: dict[str, Element] @dataclasses.dataclass(frozen=True) @@ -65,12 +64,12 @@ class ScalarType(AnyType): @dataclasses.dataclass(frozen=True) class TupleType(AnyType): - element_types: typing.Tuple[AnyType] + element_types: tuple[AnyType] @dataclasses.dataclass(frozen=True) class NamedTupleType(AnyType): - element_types: typing.Dict[str, AnyType] + element_types: dict[str, AnyType] @dataclasses.dataclass(frozen=True) @@ -80,7 +79,7 @@ class ArrayType(SequenceType): @dataclasses.dataclass(frozen=True) class EnumType(AnyType): - members: typing.Tuple[str] + members: tuple[str] @dataclasses.dataclass(frozen=True) diff --git a/edgedb/enums.py b/edgedb/enums.py index 6312f596..a9bcdbdb 100644 --- a/edgedb/enums.py +++ b/edgedb/enums.py @@ -15,30 +15,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import enum class Capability(enum.IntFlag): + NONE = 0 + MODIFICATIONS = 1 << 0 + SESSION_CONFIG = 1 << 1 + TRANSACTION = 1 << 2 + DDL = 1 << 3 + PERSISTENT_CONFIG = 1 << 4 - NONE = 0 # noqa - MODIFICATIONS = 1 << 0 # noqa - SESSION_CONFIG = 1 << 1 # noqa - TRANSACTION = 1 << 2 # noqa - DDL = 1 << 3 # noqa - PERSISTENT_CONFIG = 1 << 4 # noqa - - ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa - EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa - LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa + ALL = 0xFFFF_FFFF_FFFF_FFFF + EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG + LEGACY_EXECUTE = ALL & ~TRANSACTION class CompilationFlag(enum.IntFlag): - - INJECT_OUTPUT_TYPE_IDS = 1 << 0 # noqa - INJECT_OUTPUT_TYPE_NAMES = 1 << 1 # noqa - INJECT_OUTPUT_OBJECT_IDS = 1 << 2 # noqa + INJECT_OUTPUT_TYPE_IDS = 1 << 0 + INJECT_OUTPUT_TYPE_NAMES = 1 << 1 + INJECT_OUTPUT_OBJECT_IDS = 1 << 2 class Cardinality(enum.Enum): @@ -46,19 +44,19 @@ class Cardinality(enum.Enum): # * the query is a command like CONFIGURE that # does not return any data; # * the query is composed of multiple queries. - NO_RESULT = 0x6e + NO_RESULT = 0x6E # Cardinality is 1 or 0 - AT_MOST_ONE = 0x6f + AT_MOST_ONE = 0x6F # Cardinality is 1 ONE = 0x41 # Cardinality is >= 0 - MANY = 0x6d + MANY = 0x6D # Cardinality is >= 1 - AT_LEAST_ONE = 0x4d + AT_LEAST_ONE = 0x4D def is_single(self) -> bool: return self in {Cardinality.AT_MOST_ONE, Cardinality.ONE} @@ -68,7 +66,6 @@ def is_multi(self) -> bool: class ElementKind(enum.Enum): - - LINK = 1 # noqa - PROPERTY = 2 # noqa - LINK_PROPERTY = 3 # noqa + LINK = 1 + PROPERTY = 2 + LINK_PROPERTY = 3 diff --git a/edgedb/errors/__init__.py b/edgedb/errors/__init__.py index cfe55e1d..e5ced0ec 100644 --- a/edgedb/errors/__init__.py +++ b/edgedb/errors/__init__.py @@ -14,103 +14,103 @@ __all__ = _base.__all__ + ( # type: ignore - 'InternalServerError', - 'UnsupportedFeatureError', - 'ProtocolError', - 'BinaryProtocolError', - 'UnsupportedProtocolVersionError', - 'TypeSpecNotFoundError', - 'UnexpectedMessageError', - 'InputDataError', - 'ParameterTypeMismatchError', - 'StateMismatchError', - 'ResultCardinalityMismatchError', - 'CapabilityError', - 'UnsupportedCapabilityError', - 'DisabledCapabilityError', - 'QueryError', - 'InvalidSyntaxError', - 'EdgeQLSyntaxError', - 'SchemaSyntaxError', - 'GraphQLSyntaxError', - 'InvalidTypeError', - 'InvalidTargetError', - 'InvalidLinkTargetError', - 'InvalidPropertyTargetError', - 'InvalidReferenceError', - 'UnknownModuleError', - 'UnknownLinkError', - 'UnknownPropertyError', - 'UnknownUserError', - 'UnknownDatabaseError', - 'UnknownParameterError', - 'SchemaError', - 'SchemaDefinitionError', - 'InvalidDefinitionError', - 'InvalidModuleDefinitionError', - 'InvalidLinkDefinitionError', - 'InvalidPropertyDefinitionError', - 'InvalidUserDefinitionError', - 'InvalidDatabaseDefinitionError', - 'InvalidOperatorDefinitionError', - 'InvalidAliasDefinitionError', - 'InvalidFunctionDefinitionError', - 'InvalidConstraintDefinitionError', - 'InvalidCastDefinitionError', - 'DuplicateDefinitionError', - 'DuplicateModuleDefinitionError', - 'DuplicateLinkDefinitionError', - 'DuplicatePropertyDefinitionError', - 'DuplicateUserDefinitionError', - 'DuplicateDatabaseDefinitionError', - 'DuplicateOperatorDefinitionError', - 'DuplicateViewDefinitionError', - 'DuplicateFunctionDefinitionError', - 'DuplicateConstraintDefinitionError', - 'DuplicateCastDefinitionError', - 'DuplicateMigrationError', - 'SessionTimeoutError', - 'IdleSessionTimeoutError', - 'QueryTimeoutError', - 'TransactionTimeoutError', - 'IdleTransactionTimeoutError', - 'ExecutionError', - 'InvalidValueError', - 'DivisionByZeroError', - 'NumericOutOfRangeError', - 'AccessPolicyError', - 'QueryAssertionError', - 'IntegrityError', - 'ConstraintViolationError', - 'CardinalityViolationError', - 'MissingRequiredError', - 'TransactionError', - 'TransactionConflictError', - 'TransactionSerializationError', - 'TransactionDeadlockError', - 'WatchError', - 'ConfigurationError', - 'AccessError', - 'AuthenticationError', - 'AvailabilityError', - 'BackendUnavailableError', - 'BackendError', - 'UnsupportedBackendFeatureError', - 'LogMessage', - 'WarningMessage', - 'ClientError', - 'ClientConnectionError', - 'ClientConnectionFailedError', - 'ClientConnectionFailedTemporarilyError', - 'ClientConnectionTimeoutError', - 'ClientConnectionClosedError', - 'InterfaceError', - 'QueryArgumentError', - 'MissingArgumentError', - 'UnknownArgumentError', - 'InvalidArgumentError', - 'NoDataError', - 'InternalClientError', + "InternalServerError", + "UnsupportedFeatureError", + "ProtocolError", + "BinaryProtocolError", + "UnsupportedProtocolVersionError", + "TypeSpecNotFoundError", + "UnexpectedMessageError", + "InputDataError", + "ParameterTypeMismatchError", + "StateMismatchError", + "ResultCardinalityMismatchError", + "CapabilityError", + "UnsupportedCapabilityError", + "DisabledCapabilityError", + "QueryError", + "InvalidSyntaxError", + "EdgeQLSyntaxError", + "SchemaSyntaxError", + "GraphQLSyntaxError", + "InvalidTypeError", + "InvalidTargetError", + "InvalidLinkTargetError", + "InvalidPropertyTargetError", + "InvalidReferenceError", + "UnknownModuleError", + "UnknownLinkError", + "UnknownPropertyError", + "UnknownUserError", + "UnknownDatabaseError", + "UnknownParameterError", + "SchemaError", + "SchemaDefinitionError", + "InvalidDefinitionError", + "InvalidModuleDefinitionError", + "InvalidLinkDefinitionError", + "InvalidPropertyDefinitionError", + "InvalidUserDefinitionError", + "InvalidDatabaseDefinitionError", + "InvalidOperatorDefinitionError", + "InvalidAliasDefinitionError", + "InvalidFunctionDefinitionError", + "InvalidConstraintDefinitionError", + "InvalidCastDefinitionError", + "DuplicateDefinitionError", + "DuplicateModuleDefinitionError", + "DuplicateLinkDefinitionError", + "DuplicatePropertyDefinitionError", + "DuplicateUserDefinitionError", + "DuplicateDatabaseDefinitionError", + "DuplicateOperatorDefinitionError", + "DuplicateViewDefinitionError", + "DuplicateFunctionDefinitionError", + "DuplicateConstraintDefinitionError", + "DuplicateCastDefinitionError", + "DuplicateMigrationError", + "SessionTimeoutError", + "IdleSessionTimeoutError", + "QueryTimeoutError", + "TransactionTimeoutError", + "IdleTransactionTimeoutError", + "ExecutionError", + "InvalidValueError", + "DivisionByZeroError", + "NumericOutOfRangeError", + "AccessPolicyError", + "QueryAssertionError", + "IntegrityError", + "ConstraintViolationError", + "CardinalityViolationError", + "MissingRequiredError", + "TransactionError", + "TransactionConflictError", + "TransactionSerializationError", + "TransactionDeadlockError", + "WatchError", + "ConfigurationError", + "AccessError", + "AuthenticationError", + "AvailabilityError", + "BackendUnavailableError", + "BackendError", + "UnsupportedBackendFeatureError", + "LogMessage", + "WarningMessage", + "ClientError", + "ClientConnectionError", + "ClientConnectionFailedError", + "ClientConnectionFailedTemporarilyError", + "ClientConnectionTimeoutError", + "ClientConnectionClosedError", + "InterfaceError", + "QueryArgumentError", + "MissingArgumentError", + "UnknownArgumentError", + "InvalidArgumentError", + "NoDataError", + "InternalClientError", ) @@ -509,4 +509,3 @@ class NoDataError(ClientError): class InternalClientError(ClientError): _code = 0x_FF_04_00_00 - diff --git a/edgedb/errors/_base.py b/edgedb/errors/_base.py index 675ef567..6fe67ce2 100644 --- a/edgedb/errors/_base.py +++ b/edgedb/errors/_base.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import io import os @@ -24,16 +24,16 @@ import warnings __all__ = ( - 'EdgeDBError', 'EdgeDBMessage', + "EdgeDBError", + "EdgeDBMessage", ) class Meta(type): - def __new__(mcls, name, bases, dct): cls = super().__new__(mcls, name, bases, dct) - code = dct.get('_code') + code = dct.get("_code") if code is not None: mcls._index[code] = cls @@ -46,13 +46,11 @@ def __new__(mcls, name, bases, dct): class EdgeDBMessageMeta(Meta): - _base_class_index = {} _index = {} class EdgeDBMessage(Warning, metaclass=EdgeDBMessageMeta): - _code = None def __init__(self, severity, message): @@ -77,13 +75,11 @@ def _from_code(code, severity, message, *args, **kwargs): class EdgeDBErrorMeta(Meta): - _base_class_index = {} _index = {} class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta): - _code = None _query = None tags = frozenset() @@ -128,7 +124,7 @@ def _hint(self): def _read_str_field(self, key, default=None): val = self._attrs.get(key) if val: - return val.decode('utf-8') + return val.decode("utf-8") return default def get_code(self): @@ -207,23 +203,23 @@ def _lookup_message_cls(code: int): def _decode(code: int): - return tuple(code.to_bytes(4, 'big')) + return tuple(code.to_bytes(4, "big")) def _severity_name(severity): if severity <= EDGE_SEVERITY_DEBUG: - return 'DEBUG' + return "DEBUG" if severity <= EDGE_SEVERITY_INFO: - return 'INFO' + return "INFO" if severity <= EDGE_SEVERITY_NOTICE: - return 'NOTICE' + return "NOTICE" if severity <= EDGE_SEVERITY_WARNING: - return 'WARNING' + return "WARNING" if severity <= EDGE_SEVERITY_ERROR: - return 'ERROR' + return "ERROR" if severity <= EDGE_SEVERITY_FATAL: - return 'FATAL' - return 'PANIC' + return "FATAL" + return "PANIC" def _format_error(msg, query, start, offset, line, col, hint): @@ -259,8 +255,10 @@ def _format_error(msg, query, start, offset, line, col, hint): rv.write(f"{c.FAIL}{line}{c.ENDC}{LINESEP}") if start >= 0: # Multi-line error starts - rv.write(f"{c.BLUE}{'':>{num_len}} │ " - f"{c.FAIL}╭─{'─' * start}^{c.ENDC}{LINESEP}") + rv.write( + f"{c.BLUE}{'':>{num_len}} │ " + f"{c.FAIL}╭─{'─' * start}^{c.ENDC}{LINESEP}" + ) offset -= length start = -1 # mark multi-line else: @@ -271,20 +269,29 @@ def _format_error(msg, query, start, offset, line, col, hint): size = _unicode_width(first_half) if start >= 0: # Mark single-line error - rv.write(f"{c.BLUE}{'':>{num_len}} │ {' ' * start}" - f"{c.FAIL}{'^' * size} {hint}{c.ENDC}") + rv.write( + f"{c.BLUE}{'':>{num_len}} │ {' ' * start}" + f"{c.FAIL}{'^' * size} {hint}{c.ENDC}" + ) else: # End of multi-line error - rv.write(f"{c.BLUE}{'':>{num_len}} │ " - f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}") + rv.write( + f"{c.BLUE}{'':>{num_len}} │ " + f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}" + ) break return rv.getvalue() def _unicode_width(text): - return sum(0 if unicodedata.category(c) in ('Mn', 'Cf') else - 2 if unicodedata.east_asian_width(c) == "W" else 1 - for c in text) + return sum( + 0 + if unicodedata.category(c) in ("Mn", "Cf") + else 2 + if unicodedata.east_asian_width(c) == "W" + else 1 + for c in text + ) FIELD_HINT = 0x_00_01 diff --git a/edgedb/errors/tags.py b/edgedb/errors/tags.py index 275b31ac..8f86c5ac 100644 --- a/edgedb/errors/tags.py +++ b/edgedb/errors/tags.py @@ -1,12 +1,14 @@ +from __future__ import annotations + __all__ = [ - 'Tag', - 'SHOULD_RECONNECT', - 'SHOULD_RETRY', + "Tag", + "SHOULD_RECONNECT", + "SHOULD_RETRY", ] -class Tag(object): - """Error tag +class Tag: + """Error tag. Tags are used to differentiate certain properties of errors that apply to error classes across hierarchy. @@ -18,8 +20,8 @@ def __init__(self, name): self.name = name def __repr__(self): - return f'' + return f"" -SHOULD_RECONNECT = Tag('SHOULD_RECONNECT') -SHOULD_RETRY = Tag('SHOULD_RETRY') +SHOULD_RECONNECT = Tag("SHOULD_RECONNECT") +SHOULD_RETRY = Tag("SHOULD_RETRY") diff --git a/edgedb/introspect.py b/edgedb/introspect.py index 0db045e6..15074f90 100644 --- a/edgedb/introspect.py +++ b/edgedb/introspect.py @@ -15,10 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # - - # IMPORTANT: this private API is subject to change. - +from __future__ import annotations import functools import typing @@ -28,18 +26,16 @@ class PointerDescription(typing.NamedTuple): - name: str kind: ElementKind implicit: bool class ObjectDescription(typing.NamedTuple): - - pointers: typing.Tuple[PointerDescription, ...] + pointers: tuple[PointerDescription, ...] -@functools.lru_cache() +@functools.lru_cache def _introspect_object_desc(desc) -> ObjectDescription: pointers = [] # Call __dir__ directly as dir() scrambles the order. @@ -53,14 +49,12 @@ def _introspect_object_desc(desc) -> ObjectDescription: pointers.append( PointerDescription( - name=name, - kind=kind, - implicit=desc.is_implicit(name))) + name=name, kind=kind, implicit=desc.is_implicit(name) + ) + ) - return ObjectDescription( - pointers=tuple(pointers)) + return ObjectDescription(pointers=tuple(pointers)) def introspect_object(obj) -> ObjectDescription: - return _introspect_object_desc( - dt.get_object_descriptor(obj)) + return _introspect_object_desc(dt.get_object_descriptor(obj)) diff --git a/edgedb/options.py b/edgedb/options.py index 714168ab..8a359050 100644 --- a/edgedb/options.py +++ b/edgedb/options.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import enum import random @@ -6,28 +8,30 @@ from . import errors - _RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"]) def default_backoff(attempt): - return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001 + return (2**attempt) * 0.1 + random.randrange(100) * 0.001 class RetryCondition: - """Specific condition to retry on for fine-grained control""" + """Specific condition to retry on for fine-grained control.""" + TransactionConflict = enum.auto() NetworkError = enum.auto() class IsolationLevel: - """Isolation level for transaction""" + """Isolation level for transaction.""" + Serializable = "SERIALIZABLE" class RetryOptions: - """An immutable class that contains rules for `transaction()`""" - __slots__ = ['_default', '_overrides'] + """An immutable class that contains rules for `transaction()`.""" + + __slots__ = ["_default", "_overrides"] def __init__(self, attempts: int, backoff=default_backoff): self._default = _RetryRule(attempts, backoff) @@ -69,12 +73,13 @@ def get_rule_for_exception(self, exception): class TransactionOptions: - """Options for `transaction()`""" - __slots__ = ['_isolation', '_readonly', '_deferrable'] + """Options for `transaction()`.""" + + __slots__ = ["_isolation", "_readonly", "_deferrable"] def __init__( self, - isolation: IsolationLevel=IsolationLevel.Serializable, + isolation: IsolationLevel = IsolationLevel.Serializable, readonly: bool = False, deferrable: bool = False, ): @@ -89,35 +94,35 @@ def defaults(cls): def start_transaction_query(self): isolation = str(self._isolation) if self._readonly: - mode = 'READ ONLY' + mode = "READ ONLY" else: - mode = 'READ WRITE' + mode = "READ WRITE" if self._deferrable: - defer = 'DEFERRABLE' + defer = "DEFERRABLE" else: - defer = 'NOT DEFERRABLE' + defer = "NOT DEFERRABLE" - return f'START TRANSACTION ISOLATION {isolation}, {mode}, {defer};' + return f"START TRANSACTION ISOLATION {isolation}, {mode}, {defer};" def __repr__(self): return ( - f'<{self.__class__.__name__} ' - f'isolation:{self._isolation}, ' - f'readonly:{self._readonly}, ' - f'deferrable:{self._deferrable}>' + f"<{self.__class__.__name__} " + f"isolation:{self._isolation}, " + f"readonly:{self._readonly}, " + f"deferrable:{self._deferrable}>" ) class State: - __slots__ = ['_module', '_aliases', '_config', '_globals'] + __slots__ = ["_module", "_aliases", "_config", "_globals"] def __init__( self, - default_module: typing.Optional[str] = None, - module_aliases: typing.Mapping[str, str] = None, - config: typing.Mapping[str, typing.Any] = None, - globals_: typing.Mapping[str, typing.Any] = None, + default_module: str | None = None, + module_aliases: typing.Mapping[str, str] | None = None, + config: typing.Mapping[str, typing.Any] | None = None, + globals_: typing.Mapping[str, typing.Any] | None = None, ): self._module = default_module self._aliases = {} if module_aliases is None else dict(module_aliases) @@ -139,7 +144,7 @@ def _new(cls, default_module, module_aliases, config, globals_): def defaults(cls): return cls() - def with_default_module(self, module: typing.Optional[str] = None): + def with_default_module(self, module: str | None = None): return self._new( default_module=module, module_aliases=self._aliases, @@ -275,7 +280,9 @@ def __init__(self, *args, **kwargs): def _shallow_clone(self): pass - def with_transaction_options(self, options: TransactionOptions = None): + def with_transaction_options( + self, options: TransactionOptions | None = None + ): """Returns object with adjusted options for future transactions. :param options TransactionOptions: @@ -293,7 +300,7 @@ def with_transaction_options(self, options: TransactionOptions = None): result._options = self._options.with_transaction_options(options) return result - def with_retry_options(self, options: RetryOptions=None): + def with_retry_options(self, options: RetryOptions | None = None): """Returns object with adjusted options for future retrying transactions. @@ -306,7 +313,6 @@ def with_retry_options(self, options: RetryOptions=None): Both ``self`` and returned object can be used after, but when using them retry options applied will be different. """ - result = self._shallow_clone() result._options = self._options.with_retry_options(options) return result @@ -316,7 +322,7 @@ def with_state(self, state: State): result._options = self._options.with_state(state) return result - def with_default_module(self, module: typing.Optional[str] = None): + def with_default_module(self, module: str | None = None): result = self._shallow_clone() result._options = self._options.with_state( self._options.state.with_default_module(module) @@ -367,9 +373,9 @@ def without_globals(self, *global_names): class _Options: - """Internal class for storing connection options""" + """Internal class for storing connection options.""" - __slots__ = ['_retry_options', '_transaction_options', '_state'] + __slots__ = ["_retry_options", "_transaction_options", "_state"] def __init__( self, diff --git a/edgedb/platform.py b/edgedb/platform.py index 55410532..814fd96c 100644 --- a/edgedb/platform.py +++ b/edgedb/platform.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import functools import os import pathlib import sys if sys.platform == "darwin": + def config_dir() -> pathlib.Path: return ( pathlib.Path.home() / "Library" / "Application Support" / "edgedb" @@ -24,6 +27,7 @@ def config_dir() -> pathlib.Path: IS_WINDOWS = True else: + def config_dir() -> pathlib.Path: xdg_conf_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", ".")) if not xdg_conf_dir.is_absolute(): diff --git a/edgedb/scram/__init__.py b/edgedb/scram/__init__.py index 62f3e712..71579b62 100644 --- a/edgedb/scram/__init__.py +++ b/edgedb/scram/__init__.py @@ -17,6 +17,7 @@ # """Helpers for SCRAM authentication.""" +from __future__ import annotations import base64 import hashlib @@ -26,7 +27,6 @@ from .saslprep import saslprep - RAW_NONCE_LENGTH = 18 # Per recommendations in RFC 7677. @@ -42,8 +42,12 @@ def generate_nonce(length: int = RAW_NONCE_LENGTH) -> str: return B64(os.urandom(length)) -def build_verifier(password: str, *, salt: typing.Optional[bytes] = None, - iterations: int = DEFAULT_ITERATIONS) -> str: +def build_verifier( + password: str, + *, + salt: bytes | None = None, + iterations: int = DEFAULT_ITERATIONS, +) -> str: """Build the SCRAM verifier for the given password. Returns a string in the following format: @@ -52,7 +56,7 @@ def build_verifier(password: str, *, salt: typing.Optional[bytes] = None, The salt and keys are base64-encoded values. """ - password = saslprep(password).encode('utf-8') + password = saslprep(password).encode("utf-8") if salt is None: salt = generate_salt() @@ -62,12 +66,13 @@ def build_verifier(password: str, *, salt: typing.Optional[bytes] = None, stored_key = H(client_key) server_key = get_server_key(salted_password) - return (f'SCRAM-SHA-256${iterations}:{B64(salt)}$' - f'{B64(stored_key)}:{B64(server_key)}') + return ( + f"SCRAM-SHA-256${iterations}:{B64(salt)}$" + f"{B64(stored_key)}:{B64(server_key)}" + ) class SCRAMVerifier(typing.NamedTuple): - mechanism: str iterations: int salt: bytes @@ -76,24 +81,23 @@ class SCRAMVerifier(typing.NamedTuple): def parse_verifier(verifier: str) -> SCRAMVerifier: - - parts = verifier.split('$') + parts = verifier.split("$") if len(parts) != 3: - raise ValueError('invalid SCRAM verifier') + raise ValueError("invalid SCRAM verifier") mechanism = parts[0] - if mechanism != 'SCRAM-SHA-256': - raise ValueError('invalid SCRAM verifier') + if mechanism != "SCRAM-SHA-256": + raise ValueError("invalid SCRAM verifier") - iterations, _, salt = parts[1].partition(':') - stored_key, _, server_key = parts[2].partition(':') + iterations, _, salt = parts[1].partition(":") + stored_key, _, server_key = parts[2].partition(":") if not salt or not server_key: - raise ValueError('invalid SCRAM verifier') + raise ValueError("invalid SCRAM verifier") try: iterations = int(iterations) except ValueError: - raise ValueError('invalid SCRAM verifier') from None + raise ValueError("invalid SCRAM verifier") from None return SCRAMVerifier( mechanism=mechanism, @@ -105,7 +109,6 @@ def parse_verifier(verifier: str) -> SCRAMVerifier: def parse_client_first_message(resp: bytes): - # Relevant bits of RFC 5802: # # saslname = 1*(value-safe-char / "=2C" / "=3D") @@ -153,44 +156,44 @@ def parse_client_first_message(resp: bytes): # client-first-message = # gs2-header client-first-message-bare - attrs = resp.split(b',') + attrs = resp.split(b",") cb_attr = attrs[0] - if cb_attr == b'y': + if cb_attr == b"y": cb = True - elif cb_attr == b'n': + elif cb_attr == b"n": cb = False - elif cb_attr[0:1] == b'p': - _, _, cb = cb_attr.partition(b'=') + elif cb_attr[0:1] == b"p": + _, _, cb = cb_attr.partition(b"=") if not cb: - raise ValueError('malformed SCRAM message') + raise ValueError("malformed SCRAM message") else: - raise ValueError('malformed SCRAM message') + raise ValueError("malformed SCRAM message") authzid_attr = attrs[1] if authzid_attr: - if authzid_attr[0:1] != b'a': - raise ValueError('malformed SCRAM message') - _, _, authzid = authzid_attr.partition(b'=') + if authzid_attr[0:1] != b"a": + raise ValueError("malformed SCRAM message") + _, _, authzid = authzid_attr.partition(b"=") else: authzid = None user_attr = attrs[2] - if user_attr[0:1] == b'm': - raise ValueError('unsupported SCRAM extensions in message') - elif user_attr[0:1] != b'n': - raise ValueError('malformed SCRAM message') + if user_attr[0:1] == b"m": + raise ValueError("unsupported SCRAM extensions in message") + elif user_attr[0:1] != b"n": + raise ValueError("malformed SCRAM message") - _, _, user = user_attr.partition(b'=') + _, _, user = user_attr.partition(b"=") nonce_attr = attrs[3] - if nonce_attr[0:1] != b'r': - raise ValueError('malformed SCRAM message') + if nonce_attr[0:1] != b"r": + raise ValueError("malformed SCRAM message") - _, _, nonce_bin = nonce_attr.partition(b'=') - nonce = nonce_bin.decode('ascii') + _, _, nonce_bin = nonce_attr.partition(b"=") + nonce = nonce_bin.decode("ascii") if not nonce.isprintable(): - raise ValueError('invalid characters in client nonce') + raise ValueError("invalid characters in client nonce") # ["," extensions] are ignored @@ -198,8 +201,8 @@ def parse_client_first_message(resp: bytes): def parse_client_final_message( - msg: bytes, client_nonce: str, server_nonce: str): - + msg: bytes, client_nonce: str, server_nonce: str +): # Relevant bits of RFC 5802: # # gs2-header = gs2-cbind-flag "," [ authzid ] "," @@ -226,82 +229,78 @@ def parse_client_final_message( # client-final-message = # client-final-message-without-proof "," proof - attrs = msg.split(b',') + attrs = msg.split(b",") cb_attr = attrs[0] - if cb_attr[0:1] != b'c': - raise ValueError('malformed SCRAM message') + if cb_attr[0:1] != b"c": + raise ValueError("malformed SCRAM message") - _, _, cb_data = cb_attr.partition(b'=') + _, _, cb_data = cb_attr.partition(b"=") nonce_attr = attrs[1] - if nonce_attr[0:1] != b'r': - raise ValueError('malformed SCRAM message') + if nonce_attr[0:1] != b"r": + raise ValueError("malformed SCRAM message") - _, _, nonce_bin = nonce_attr.partition(b'=') - nonce = nonce_bin.decode('ascii') + _, _, nonce_bin = nonce_attr.partition(b"=") + nonce = nonce_bin.decode("ascii") - expected_nonce = f'{client_nonce}{server_nonce}' + expected_nonce = f"{client_nonce}{server_nonce}" if nonce != expected_nonce: raise ValueError( - 'invalid SCRAM client-final message: nonce does not match') + "invalid SCRAM client-final message: nonce does not match" + ) proof = None for attr in attrs[2:]: - if attr[0:1] == b'p': - _, _, proof = attr.partition(b'=') + if attr[0:1] == b"p": + _, _, proof = attr.partition(b"=") proof_attr_len = len(attr) proof = base64.b64decode(proof) elif proof is not None: - raise ValueError('malformed SCRAM message') + raise ValueError("malformed SCRAM message") if proof is None: - raise ValueError('malformed SCRAM message') + raise ValueError("malformed SCRAM message") return cb_data, proof, proof_attr_len + 1 def build_client_first_message(client_nonce: str, username: str) -> str: - - bare = f'n={saslprep(username)},r={client_nonce}' - return f'n,,{bare}', bare + bare = f"n={saslprep(username)},r={client_nonce}" + return f"n,,{bare}", bare -def build_server_first_message(server_nonce: str, client_nonce: str, - salt: bytes, iterations: int) -> str: - - return ( - f'r={client_nonce}{server_nonce},' - f's={B64(salt)},i={iterations}' - ) +def build_server_first_message( + server_nonce: str, client_nonce: str, salt: bytes, iterations: int +) -> str: + return f"r={client_nonce}{server_nonce}," f"s={B64(salt)},i={iterations}" def build_auth_message( - client_first_bare: bytes, - server_first: bytes, client_final: bytes) -> bytes: - - return b'%b,%b,%b' % (client_first_bare, server_first, client_final) + client_first_bare: bytes, server_first: bytes, client_final: bytes +) -> bytes: + return b"%b,%b,%b" % (client_first_bare, server_first, client_final) def build_client_final_message( - password: str, - salt: bytes, - iterations: int, - client_first_bare: bytes, - server_first: bytes, - server_nonce: str) -> str: - - client_final = f'c=biws,r={server_nonce}' + password: str, + salt: bytes, + iterations: int, + client_first_bare: bytes, + server_first: bytes, + server_nonce: str, +) -> str: + client_final = f"c=biws,r={server_nonce}" AuthMessage = build_auth_message( - client_first_bare, server_first, client_final.encode('utf-8')) + client_first_bare, server_first, client_final.encode("utf-8") + ) SaltedPassword = get_salted_password( - saslprep(password).encode('utf-8'), - salt, - iterations) + saslprep(password).encode("utf-8"), salt, iterations + ) ClientKey = get_client_key(SaltedPassword) StoredKey = H(ClientKey) @@ -311,62 +310,63 @@ def build_client_final_message( ServerKey = get_server_key(SaltedPassword) ServerProof = HMAC(ServerKey, AuthMessage) - return f'{client_final},p={B64(ClientProof)}', ServerProof + return f"{client_final},p={B64(ClientProof)}", ServerProof def build_server_final_message( - client_first_bare: bytes, server_first: bytes, - client_final: bytes, server_key: bytes) -> str: - + client_first_bare: bytes, + server_first: bytes, + client_final: bytes, + server_key: bytes, +) -> str: AuthMessage = build_auth_message( - client_first_bare, server_first, client_final) + client_first_bare, server_first, client_final + ) ServerSignature = HMAC(server_key, AuthMessage) - return f'v={B64(ServerSignature)}' + return f"v={B64(ServerSignature)}" def parse_server_first_message(msg: bytes): - - attrs = msg.split(b',') + attrs = msg.split(b",") nonce_attr = attrs[0] - if nonce_attr[0:1] != b'r': - raise ValueError('malformed SCRAM message') + if nonce_attr[0:1] != b"r": + raise ValueError("malformed SCRAM message") - _, _, nonce_bin = nonce_attr.partition(b'=') - nonce = nonce_bin.decode('ascii') + _, _, nonce_bin = nonce_attr.partition(b"=") + nonce = nonce_bin.decode("ascii") if not nonce.isprintable(): - raise ValueError('malformed SCRAM message') + raise ValueError("malformed SCRAM message") salt_attr = attrs[1] - if salt_attr[0:1] != b's': - raise ValueError('malformed SCRAM message') + if salt_attr[0:1] != b"s": + raise ValueError("malformed SCRAM message") - _, _, salt_b64 = salt_attr.partition(b'=') + _, _, salt_b64 = salt_attr.partition(b"=") salt = base64.b64decode(salt_b64) iter_attr = attrs[2] - if iter_attr[0:1] != b'i': - raise ValueError('malformed SCRAM message') + if iter_attr[0:1] != b"i": + raise ValueError("malformed SCRAM message") - _, _, iterations = iter_attr.partition(b'=') + _, _, iterations = iter_attr.partition(b"=") try: itercount = int(iterations) except ValueError: - raise ValueError('malformed SCRAM message') from None + raise ValueError("malformed SCRAM message") from None return nonce, salt, itercount def parse_server_final_message(msg: bytes): - - attrs = msg.split(b',') + attrs = msg.split(b",") nonce_attr = attrs[0] - if nonce_attr[0:1] != b'v': - raise ValueError('malformed SCRAM message') + if nonce_attr[0:1] != b"v": + raise ValueError("malformed SCRAM message") - _, _, signature_b64 = nonce_attr.partition(b'=') + _, _, signature_b64 = nonce_attr.partition(b"=") signature = base64.b64decode(signature_b64) return signature @@ -377,17 +377,20 @@ def verify_password(password: bytes, verifier: str) -> bool: Returns True if the password is OK, False otherwise. """ - - password = saslprep(password).encode('utf-8') + password = saslprep(password).encode("utf-8") v = parse_verifier(verifier) salted_password = get_salted_password(password, v.salt, v.iterations) computed_key = get_server_key(salted_password) return v.server_key == computed_key -def verify_client_proof(client_first: bytes, server_first: bytes, - client_final: bytes, StoredKey: bytes, - ClientProof: bytes) -> bool: +def verify_client_proof( + client_first: bytes, + server_first: bytes, + client_final: bytes, + StoredKey: bytes, + ClientProof: bytes, +) -> bool: AuthMessage = build_auth_message(client_first, server_first, client_final) ClientSignature = HMAC(StoredKey, AuthMessage) ClientKey = XOR(ClientProof, ClientSignature) @@ -405,19 +408,20 @@ def HMAC(key: bytes, msg: bytes) -> bytes: def XOR(a: bytes, b: bytes) -> bytes: if len(a) != len(b): - raise ValueError('scram.XOR received operands of unequal length') - xint = int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big') - return xint.to_bytes(len(a), 'big') + raise ValueError("scram.XOR received operands of unequal length") + xint = int.from_bytes(a, "big") ^ int.from_bytes(b, "big") + return xint.to_bytes(len(a), "big") def H(s: bytes) -> bytes: return hashlib.sha256(s).digest() -def get_salted_password(password: bytes, salt: bytes, - iterations: int) -> bytes: +def get_salted_password( + password: bytes, salt: bytes, iterations: int +) -> bytes: # U1 := HMAC(str, salt + INT(1)) - H_i = U_i = HMAC(password, salt + b'\x00\x00\x00\x01') + H_i = U_i = HMAC(password, salt + b"\x00\x00\x00\x01") for _ in range(iterations - 1): U_i = HMAC(password, U_i) @@ -427,8 +431,8 @@ def get_salted_password(password: bytes, salt: bytes, def get_client_key(salted_password: bytes) -> bytes: - return HMAC(salted_password, b'Client Key') + return HMAC(salted_password, b"Client Key") def get_server_key(salted_password: bytes) -> bytes: - return HMAC(salted_password, b'Server Key') + return HMAC(salted_password, b"Server Key") diff --git a/edgedb/scram/saslprep.py b/edgedb/scram/saslprep.py index 79eb84d8..f85222f3 100644 --- a/edgedb/scram/saslprep.py +++ b/edgedb/scram/saslprep.py @@ -1,4 +1,6 @@ # Copyright 2016-present MongoDB, Inc. +from __future__ import annotations + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,11 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import stringprep import unicodedata - # RFC4013 section 2.3 prohibited output. _PROHIBITED = ( # A strict reading of RFC 4013 requires table c12 here, but @@ -29,17 +29,17 @@ stringprep.in_table_c6, stringprep.in_table_c7, stringprep.in_table_c8, - stringprep.in_table_c9) + stringprep.in_table_c9, +) def saslprep(data: str, prohibit_unassigned_code_points=True): """An implementation of RFC4013 SASLprep.""" - - if data == '': + if data == "": return data if prohibit_unassigned_code_points: - prohibited = _PROHIBITED + (stringprep.in_table_a1,) + prohibited = (*_PROHIBITED, stringprep.in_table_a1) else: prohibited = _PROHIBITED @@ -49,13 +49,17 @@ def saslprep(data: str, prohibit_unassigned_code_points=True): # commonly mapped to nothing characters to, well, nothing. in_table_c12 = stringprep.in_table_c12 in_table_b1 = stringprep.in_table_b1 - data = u"".join( - [u"\u0020" if in_table_c12(elt) else elt - for elt in data if not in_table_b1(elt)]) + data = "".join( + [ + "\u0020" if in_table_c12(elt) else elt + for elt in data + if not in_table_b1(elt) + ] + ) # RFC3454 section 2, step 2 - Normalize # RFC4013 section 2.2 normalization - data = unicodedata.ucd_3_2_0.normalize('NFKC', data) + data = unicodedata.ucd_3_2_0.normalize("NFKC", data) in_table_d1 = stringprep.in_table_d1 if in_table_d1(data[0]): @@ -66,17 +70,16 @@ def saslprep(data: str, prohibit_unassigned_code_points=True): raise ValueError("SASLprep: failed bidirectional check") # RFC3454, Section 6, #2. If a string contains any RandALCat # character, it MUST NOT contain any LCat character. - prohibited = prohibited + (stringprep.in_table_d2,) + prohibited = (*prohibited, stringprep.in_table_d2) else: # RFC3454, Section 6, #3. Following the logic of #3, if # the first character is not a RandALCat, no other character # can be either. - prohibited = prohibited + (in_table_d1,) + prohibited = (*prohibited, in_table_d1) # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi for char in data: if any(in_table(char) for in_table in prohibited): - raise ValueError( - "SASLprep: failed prohibited character check") + raise ValueError("SASLprep: failed prohibited character check") return data diff --git a/edgedb/transaction.py b/edgedb/transaction.py index 511b8f42..05f67dc0 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -15,13 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import enum -from . import abstract -from . import errors -from . import options +from . import abstract, errors, options class TransactionState(enum.Enum): @@ -33,15 +31,14 @@ class TransactionState(enum.Enum): class BaseTransaction: - __slots__ = ( - '_client', - '_connection', - '_options', - '_state', - '__retry', - '__iteration', - '__started', + "_client", + "_connection", + "_options", + "_state", + "__retry", + "__iteration", + "__started", ) def __init__(self, retry, client, iteration): @@ -59,53 +56,55 @@ def is_active(self) -> bool: def __check_state_base(self, opname): if self._state is TransactionState.COMMITTED: raise errors.InterfaceError( - 'cannot {}; the transaction is already committed'.format( - opname)) + f"cannot {opname}; the transaction is already committed" + ) if self._state is TransactionState.ROLLEDBACK: raise errors.InterfaceError( - 'cannot {}; the transaction is already rolled back'.format( - opname)) + f"cannot {opname}; the transaction is already rolled back" + ) if self._state is TransactionState.FAILED: raise errors.InterfaceError( - 'cannot {}; the transaction is in error state'.format( - opname)) + f"cannot {opname}; the transaction is in error state" + ) def __check_state(self, opname): if self._state is not TransactionState.STARTED: if self._state is TransactionState.NEW: raise errors.InterfaceError( - 'cannot {}; the transaction is not yet started'.format( - opname)) + f"cannot {opname}; the transaction is not yet started" + ) self.__check_state_base(opname) def _make_start_query(self): - self.__check_state_base('start') + self.__check_state_base("start") if self._state is TransactionState.STARTED: raise errors.InterfaceError( - 'cannot start; the transaction is already started') + "cannot start; the transaction is already started" + ) return self._options.start_transaction_query() def _make_commit_query(self): - self.__check_state('commit') - return 'COMMIT;' + self.__check_state("commit") + return "COMMIT;" def _make_rollback_query(self): - self.__check_state('rollback') - return 'ROLLBACK;' + self.__check_state("rollback") + return "ROLLBACK;" def __repr__(self): attrs = [] - attrs.append('state:{}'.format(self._state.name.lower())) + attrs.append(f"state:{self._state.name.lower()}") attrs.append(repr(self._options)) - if self.__class__.__module__.startswith('edgedb.'): - mod = 'edgedb' + if self.__class__.__module__.startswith("edgedb."): + mod = "edgedb" else: mod = self.__class__.__module__ - return '<{}.{} {} {:#x}>'.format( - mod, self.__class__.__name__, ' '.join(attrs), id(self)) + return "<{}.{} {} {:#x}>".format( + mod, self.__class__.__name__, " ".join(attrs), id(self) + ) async def _ensure_transaction(self): if not self.__started: @@ -173,9 +172,9 @@ async def _exit(self, extype, ex): await self._client._impl.release(self._connection) if ( - extype is not None and - issubclass(extype, errors.EdgeDBError) and - ex.has_tag(errors.SHOULD_RETRY) + extype is not None + and issubclass(extype, errors.EdgeDBError) + and ex.has_tag(errors.SHOULD_RETRY) ): return self.__retry._retry(ex) @@ -194,15 +193,16 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None: await self._connection._execute(execute_context) async def _privileged_execute(self, query: str) -> None: - await self._connection.privileged_execute(abstract.ExecuteContext( - query=abstract.QueryWithArgs(query, (), {}), - cache=self._get_query_cache(), - state=self._get_state(), - )) + await self._connection.privileged_execute( + abstract.ExecuteContext( + query=abstract.QueryWithArgs(query, (), {}), + cache=self._get_query_cache(), + state=self._get_state(), + ) + ) class BaseRetry: - def __init__(self, owner): self._owner = owner self._iteration = 0 diff --git a/setup.py b/setup.py index 0d9b86c6..ccdf5865 100644 --- a/setup.py +++ b/setup.py @@ -15,12 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import sys if sys.version_info < (3, 8): - raise RuntimeError('edgedb requires Python 3.8 or greater') + raise RuntimeError("edgedb requires Python 3.8 or greater") import os import os.path @@ -37,92 +37,104 @@ from setuptools.command import build_py as setuptools_build_py from setuptools.command import sdist as setuptools_sdist - -CYTHON_DEPENDENCY = 'Cython(>=0.29.24,<0.30.0)' +CYTHON_DEPENDENCY = "Cython(>=0.29.24,<0.30.0)" # Minimal dependencies required to test edgedb. TEST_DEPENDENCIES = [ # pycodestyle is a dependency of flake8, but it must be frozen because # their combination breaks too often # (example breakage: https://gitlab.com/pycqa/flake8/issues/427) - 'pycodestyle~=2.6.0', - 'pyflakes~=2.2.0', - 'flake8-bugbear~=21.4.3', - 'flake8~=3.8.1', + "pycodestyle~=2.6.0", + "pyflakes~=2.2.0", + "flake8-bugbear~=21.4.3", + "flake8~=3.8.1", 'uvloop>=0.15.1; platform_system != "Windows"', ] # Dependencies required to build documentation. DOC_DEPENDENCIES = [ - 'sphinx~=4.2.0', - 'sphinxcontrib-asyncio~=0.3.0', - 'sphinx_rtd_theme~=1.0.0', + "sphinx~=4.2.0", + "sphinxcontrib-asyncio~=0.3.0", + "sphinx_rtd_theme~=1.0.0", ] EXTRA_DEPENDENCIES = { - 'docs': DOC_DEPENDENCIES, - 'test': TEST_DEPENDENCIES, + "docs": DOC_DEPENDENCIES, + "test": TEST_DEPENDENCIES, # Dependencies required to develop edgedb. - 'dev': [ + "dev": [ CYTHON_DEPENDENCY, - 'pytest>=3.6.0', - ] + DOC_DEPENDENCIES + TEST_DEPENDENCIES + "pytest>=3.6.0", + *DOC_DEPENDENCIES, + *TEST_DEPENDENCIES, + ], } -CFLAGS = ['-O2'] +CFLAGS = ["-O2"] LDFLAGS = [] SYSTEM = sys.platform -if SYSTEM != 'win32': - CFLAGS.extend(['-std=gnu99', '-fsigned-char', '-Wall', - '-Wsign-compare', '-Wconversion']) +if SYSTEM != "win32": + CFLAGS.extend( + [ + "-std=gnu99", + "-fsigned-char", + "-Wall", + "-Wsign-compare", + "-Wconversion", + ] + ) -if SYSTEM == 'darwin': +if SYSTEM == "darwin": # Lots of warnings from the standard library on macOS 10.14 - CFLAGS.extend(['-Wno-nullability-completeness']) + CFLAGS.extend(["-Wno-nullability-completeness"]) _ROOT = pathlib.Path(__file__).parent -with open(str(_ROOT / 'README.rst')) as f: +with open(str(_ROOT / "README.rst")) as f: readme = f.read() -with open(str(_ROOT / 'edgedb' / '_version.py')) as f: +with open(str(_ROOT / "edgedb" / "_version.py")) as f: for line in f: - if line.startswith('__version__ ='): - _, _, version = line.partition('=') + if line.startswith("__version__ ="): + _, _, version = line.partition("=") VERSION = version.strip(" \n'\"") break else: raise RuntimeError( - 'unable to read the version from edgedb/_version.py') + "unable to read the version from edgedb/_version.py" + ) -if (_ROOT / '.git').is_dir() and 'dev' in VERSION: +if (_ROOT / ".git").is_dir() and "dev" in VERSION: # This is a git checkout, use git to # generate a precise version. def git_commitish(): env = {} - v = os.environ.get('PATH') + v = os.environ.get("PATH") if v is not None: - env['PATH'] = v - - git = subprocess.run(['git', 'rev-parse', 'HEAD'], env=env, - cwd=str(_ROOT), stdout=subprocess.PIPE) + env["PATH"] = v + + git = subprocess.run( + ["git", "rev-parse", "HEAD"], + env=env, + cwd=str(_ROOT), + stdout=subprocess.PIPE, + ) if git.returncode == 0: - commitish = git.stdout.strip().decode('ascii') + commitish = git.stdout.strip().decode("ascii") else: - commitish = 'unknown' + commitish = "unknown" return commitish - VERSION += '+' + git_commitish()[:7] + VERSION += "+" + git_commitish()[:7] class VersionMixin: - def _fix_version(self, filename): # Replace edgedb.__version__ with the actual version # of the distribution (possibly inferred from git). @@ -131,56 +143,59 @@ def _fix_version(self, filename): content = f.read() version_re = r"(.*__version__\s*=\s*)'[^']+'(.*)" - repl = r"\1'{}'\2".format(self.distribution.metadata.version) + repl = rf"\1'{self.distribution.metadata.version}'\2" content = re.sub(version_re, repl, content) - with open(str(filename), 'w') as f: + with open(str(filename), "w") as f: f.write(content) class sdist(setuptools_sdist.sdist, VersionMixin): - def make_release_tree(self, base_dir, files): super().make_release_tree(base_dir, files) - self._fix_version(pathlib.Path(base_dir) / 'edgedb' / '_version.py') + self._fix_version(pathlib.Path(base_dir) / "edgedb" / "_version.py") class build_py(setuptools_build_py.build_py, VersionMixin): - def build_module(self, module, module_file, package): outfile, copied = super().build_module(module, module_file, package) - if module == '_version' and package == 'edgedb': + if module == "_version" and package == "edgedb": self._fix_version(outfile) return outfile, copied class build_ext(distutils_build_ext.build_ext): - - user_options = distutils_build_ext.build_ext.user_options + [ - ('cython-always', None, - 'run cythonize() even if .c files are present'), - ('cython-annotate', None, - 'Produce a colorized HTML version of the Cython source.'), - ('cython-directives=', None, - 'Cython compiler directives'), + user_options = [ + *distutils_build_ext.build_ext.user_options, + ( + "cython-always", + None, + "run cythonize() even if .c files are present", + ), + ( + "cython-annotate", + None, + "Produce a colorized HTML version of the Cython source.", + ), + ("cython-directives=", None, "Cython compiler directives"), ] def initialize_options(self): # initialize_options() may be called multiple times on the # same command object, so make sure not to override previously # set options. - if getattr(self, '_initialized', False): + if getattr(self, "_initialized", False): return - super(build_ext, self).initialize_options() + super().initialize_options() - if os.environ.get('EDGEDB_DEBUG'): + if os.environ.get("EDGEDB_DEBUG"): self.cython_always = True self.cython_annotate = True self.cython_directives = "linetrace=True" - self.define = 'PG_DEBUG,CYTHON_TRACE,CYTHON_TRACE_NOGIL' + self.define = "PG_DEBUG,CYTHON_TRACE,CYTHON_TRACE_NOGIL" self.debug = True else: self.cython_always = False @@ -192,7 +207,7 @@ def finalize_options(self): # finalize_options() may be called multiple times on the # same command object, so make sure not to override previously # set options. - if getattr(self, '_initialized', False): + if getattr(self, "_initialized", False): return need_cythonize = self.cython_always @@ -200,9 +215,9 @@ def finalize_options(self): for extension in self.distribution.ext_modules: for i, sfile in enumerate(extension.sources): - if sfile.endswith('.pyx'): + if sfile.endswith(".pyx"): prefix, ext = os.path.splitext(sfile) - cfile = prefix + '.c' + cfile = prefix + ".c" if os.path.exists(cfile) and not self.cython_always: extension.sources[i] = cfile @@ -224,27 +239,28 @@ def finalize_options(self): import Cython except ImportError: raise RuntimeError( - 'please install {} to compile edgedb from source'.format( - CYTHON_DEPENDENCY)) + "please install {} to compile edgedb from source".format( + CYTHON_DEPENDENCY + ) + ) cython_dep = pkg_resources.Requirement.parse(CYTHON_DEPENDENCY) if Cython.__version__ not in cython_dep: raise RuntimeError( - 'edgedb requires {}, got Cython=={}'.format( + "edgedb requires {}, got Cython=={}".format( CYTHON_DEPENDENCY, Cython.__version__ - )) + ) + ) from Cython.Build import cythonize - directives = { - 'language_level': '3' - } + directives = {"language_level": "3"} if self.cython_directives: - for directive in self.cython_directives.split(','): - k, _, v = directive.partition('=') - if v.lower() == 'false': + for directive in self.cython_directives.split(","): + k, _, v = directive.partition("=") + if v.lower() == "false": v = False - if v.lower() == 'true': + if v.lower() == "true": v = True directives[k] = v @@ -252,89 +268,95 @@ def finalize_options(self): self.distribution.ext_modules[:] = cythonize( self.distribution.ext_modules, compiler_directives=directives, - annotate=self.cython_annotate) + annotate=self.cython_annotate, + ) - super(build_ext, self).finalize_options() + super().finalize_options() INCLUDE_DIRS = [ - 'edgedb/pgproto/', - 'edgedb/datatypes', + "edgedb/pgproto/", + "edgedb/datatypes", ] setup_requires = [] -if (not (_ROOT / 'edgedb' / 'protocol' / 'protocol.c').exists() or - '--cython-always' in sys.argv): +if ( + not (_ROOT / "edgedb" / "protocol" / "protocol.c").exists() + or "--cython-always" in sys.argv +): # No Cython output, require Cython to build. setup_requires.append(CYTHON_DEPENDENCY) -with open(str(_ROOT / 'README.rst')) as f: +with open(str(_ROOT / "README.rst")) as f: readme = f.read() setuptools.setup( - name='edgedb', + name="edgedb", version=VERSION, - description='EdgeDB Python driver', + description="EdgeDB Python driver", long_description=readme, - platforms=['macOS', 'POSIX', 'Windows'], - author='MagicStack Inc', - author_email='hello@magic.io', - url='https://github.com/edgedb/edgedb-python', - license='Apache License, Version 2.0', + platforms=["macOS", "POSIX", "Windows"], + author="MagicStack Inc", + author_email="hello@magic.io", + url="https://github.com/edgedb/edgedb-python", + license="Apache License, Version 2.0", packages=setuptools.find_packages(), - provides=['edgedb'], + provides=["edgedb"], zip_safe=False, include_package_data=True, - package_data={'edgedb': ['py.typed']}, + package_data={"edgedb": ["py.typed"]}, ext_modules=[ distutils_extension.Extension( "edgedb.pgproto.pgproto", ["edgedb/pgproto/pgproto.pyx"], extra_compile_args=CFLAGS, - extra_link_args=LDFLAGS), - + extra_link_args=LDFLAGS, + ), distutils_extension.Extension( "edgedb.datatypes.datatypes", - ["edgedb/datatypes/args.c", - "edgedb/datatypes/record_desc.c", - "edgedb/datatypes/namedtuple.c", - "edgedb/datatypes/object.c", - "edgedb/datatypes/hash.c", - "edgedb/datatypes/link.c", - "edgedb/datatypes/linkset.c", - "edgedb/datatypes/repr.c", - "edgedb/datatypes/comp.c", - "edgedb/datatypes/datatypes.pyx"], + [ + "edgedb/datatypes/args.c", + "edgedb/datatypes/record_desc.c", + "edgedb/datatypes/namedtuple.c", + "edgedb/datatypes/object.c", + "edgedb/datatypes/hash.c", + "edgedb/datatypes/link.c", + "edgedb/datatypes/linkset.c", + "edgedb/datatypes/repr.c", + "edgedb/datatypes/comp.c", + "edgedb/datatypes/datatypes.pyx", + ], extra_compile_args=CFLAGS, - extra_link_args=LDFLAGS), - + extra_link_args=LDFLAGS, + ), distutils_extension.Extension( "edgedb.protocol.protocol", ["edgedb/protocol/protocol.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS, - include_dirs=INCLUDE_DIRS), - + include_dirs=INCLUDE_DIRS, + ), distutils_extension.Extension( "edgedb.protocol.asyncio_proto", ["edgedb/protocol/asyncio_proto.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS, - include_dirs=INCLUDE_DIRS), - + include_dirs=INCLUDE_DIRS, + ), distutils_extension.Extension( "edgedb.protocol.blocking_proto", ["edgedb/protocol/blocking_proto.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS, - include_dirs=INCLUDE_DIRS), + include_dirs=INCLUDE_DIRS, + ), ], - cmdclass={'build_ext': build_ext}, - test_suite='tests.suite', + cmdclass={"build_ext": build_ext}, + test_suite="tests.suite", python_requires=">=3.7", install_requires=[ 'certifi>=2021.5.30; platform_system == "Windows"', @@ -345,5 +367,5 @@ def finalize_options(self): "console_scripts": [ "edgedb-py=edgedb.codegen.cli:main", ] - } + }, ) diff --git a/tests/__init__.py b/tests/__init__.py index e3f19011..20b10435 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -24,12 +24,13 @@ def suite(): test_loader = unittest.TestLoader() - test_suite = test_loader.discover(str(pathlib.Path(__file__).parent), - pattern='test_*.py') + test_suite = test_loader.discover( + str(pathlib.Path(__file__).parent), pattern="test_*.py" + ) return test_suite -if __name__ == '__main__': +if __name__ == "__main__": runner = unittest.runner.TextTestRunner(verbosity=2) result = runner.run(suite()) sys.exit(not result.wasSuccessful()) diff --git a/tests/bench_uuid.py b/tests/bench_uuid.py index 5323bb51..b70f17bb 100644 --- a/tests/bench_uuid.py +++ b/tests/bench_uuid.py @@ -19,8 +19,9 @@ assert issubclass(c_UUID, std_UUID) assert isinstance(c_UUID(ubytes), std_UUID) -assert c_UUID(ubytes).bytes == std_UUID(bytes=ubytes).bytes, \ - f'{ubytes}: {c_UUID(ubytes).bytes}' +assert ( + c_UUID(ubytes).bytes == std_UUID(bytes=ubytes).bytes +), f"{ubytes}: {c_UUID(ubytes).bytes}" assert c_UUID(ubytes).hex == std_UUID(bytes=ubytes).hex assert c_UUID(ubytes).int == std_UUID(bytes=ubytes).int assert c_UUID(str(std_UUID(bytes=ubytes))).int == std_UUID(bytes=ubytes).int @@ -40,18 +41,18 @@ for _ in range(N): std_UUID(bytes=ubytes) std_total = time.monotonic() - st -print(f'std_UUID(bytes):\t {std_total:.4f}') +print(f"std_UUID(bytes):\t {std_total:.4f}") st = time.monotonic() for _ in range(N): c_UUID(ubytes) c_total = time.monotonic() - st -print(f'c_UUID(bytes):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"c_UUID(bytes):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") st = time.monotonic() for _ in range(N): object() -print(f'object():\t\t {time.monotonic() - st:.4f}') +print(f"object():\t\t {time.monotonic() - st:.4f}") print() @@ -60,13 +61,13 @@ for _ in range(N): std_UUID(ustr) std_total = time.monotonic() - st -print(f'std_UUID(str):\t\t {std_total:.4f}') +print(f"std_UUID(str):\t\t {std_total:.4f}") st = time.monotonic() for _ in range(N): c_UUID(ustr) c_total = time.monotonic() - st -print(f'c_UUID(str):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"c_UUID(str):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") print() @@ -76,20 +77,20 @@ for _ in range(N): str(u) std_total = time.monotonic() - st -print(f'str(std_UUID()):\t {std_total:.4f}') +print(f"str(std_UUID()):\t {std_total:.4f}") u = c_UUID(ubytes) st = time.monotonic() for _ in range(N): str(u) c_total = time.monotonic() - st -print(f'str(c_UUID()):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"str(c_UUID()):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") u = object() st = time.monotonic() for _ in range(N): str(u) -print(f'str(object()):\t\t {time.monotonic() - st:.4f}') +print(f"str(object()):\t\t {time.monotonic() - st:.4f}") print() @@ -99,7 +100,7 @@ for _ in range(N): u.bytes std_total = time.monotonic() - st -print(f'std_UUID().bytes:\t {std_total:.4f}') +print(f"std_UUID().bytes:\t {std_total:.4f}") u = c_UUID(ubytes) @@ -107,7 +108,7 @@ for _ in range(N): u.bytes c_total = time.monotonic() - st -print(f'c_UUID().bytes:\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"c_UUID().bytes:\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") print() @@ -116,7 +117,7 @@ for _ in range(N): u.int std_total = time.monotonic() - st -print(f'std_UUID().int:\t\t {std_total:.4f}') +print(f"std_UUID().int:\t\t {std_total:.4f}") u = c_UUID(ubytes) @@ -124,7 +125,7 @@ for _ in range(N): u.int c_total = time.monotonic() - st -print(f'c_UUID().int:\t\t* {c_total:.4f}') +print(f"c_UUID().int:\t\t* {c_total:.4f}") print() @@ -133,7 +134,7 @@ for _ in range(N): u.hex std_total = time.monotonic() - st -print(f'std_UUID().hex:\t\t {std_total:.4f}') +print(f"std_UUID().hex:\t\t {std_total:.4f}") u = c_UUID(ubytes) @@ -141,7 +142,7 @@ for _ in range(N): u.hex c_total = time.monotonic() - st -print(f'c_UUID().hex:\t\t* {c_total:.4f}') +print(f"c_UUID().hex:\t\t* {c_total:.4f}") print() @@ -150,7 +151,7 @@ for _ in range(N): hash(u) std_total = time.monotonic() - st -print(f'hash(std_UUID()):\t {std_total:.4f}') +print(f"hash(std_UUID()):\t {std_total:.4f}") u = c_UUID(ubytes) @@ -158,7 +159,7 @@ for _ in range(N): hash(u) c_total = time.monotonic() - st -print(f'hash(c_UUID()):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"hash(c_UUID()):\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") print() @@ -169,7 +170,7 @@ for _ in range(N): dct.get(u) std_total = time.monotonic() - st -print(f'dct[std_UUID()]:\t {std_total:.4f}') +print(f"dct[std_UUID()]:\t {std_total:.4f}") u = c_UUID(ubytes) @@ -177,7 +178,7 @@ for _ in range(N): dct.get(u) c_total = time.monotonic() - st -print(f'dct[c_UUID()]:\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"dct[c_UUID()]:\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") print() @@ -186,7 +187,7 @@ for _ in range(N): _ = u == TEST_UUID std_total = time.monotonic() - st -print(f'std_UUID() ==:\t\t {std_total:.4f}') +print(f"std_UUID() ==:\t\t {std_total:.4f}") u = c_UUID(ubytes) @@ -194,4 +195,4 @@ for _ in range(N): _ = u == TEST_CUUID c_total = time.monotonic() - st -print(f'c_UUID() ==:\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)') +print(f"c_UUID() ==:\t\t* {c_total:.4f} ({std_total / c_total:.2f}x)") diff --git a/tests/codegen/test-project1/linked b/tests/codegen/test-project1/linked index 4433438d..62d256ba 120000 --- a/tests/codegen/test-project1/linked +++ b/tests/codegen/test-project1/linked @@ -1 +1 @@ -../linked/ \ No newline at end of file +../linked/ diff --git a/tests/datatypes/test_datatypes.py b/tests/datatypes/test_datatypes.py index eaff8aff..8d352ff5 100644 --- a/tests/datatypes/test_datatypes.py +++ b/tests/datatypes/test_datatypes.py @@ -36,97 +36,101 @@ def wrapper(*args, **kwargs): with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) return f(*args, **kwargs) + return wrapper class TestRecordDesc(unittest.TestCase): - def test_recorddesc_1(self): - with self.assertRaisesRegex(TypeError, 'one to three positional'): + with self.assertRaisesRegex(TypeError, "one to three positional"): private._RecordDescriptor() - with self.assertRaisesRegex(TypeError, 'one to three positional'): + with self.assertRaisesRegex(TypeError, "one to three positional"): private._RecordDescriptor(t=1) - with self.assertRaisesRegex(TypeError, 'requires a tuple'): + with self.assertRaisesRegex(TypeError, "requires a tuple"): private._RecordDescriptor(1) - with self.assertRaisesRegex(TypeError, 'requires a tuple'): - private._RecordDescriptor(('a',), 1) + with self.assertRaisesRegex(TypeError, "requires a tuple"): + private._RecordDescriptor(("a",), 1) - with self.assertRaisesRegex(TypeError, - 'the same length as the names tuple'): - private._RecordDescriptor(('a',), ()) + with self.assertRaisesRegex( + TypeError, "the same length as the names tuple" + ): + private._RecordDescriptor(("a",), ()) - private._RecordDescriptor(('a', 'b')) + private._RecordDescriptor(("a", "b")) - with self.assertRaisesRegex(ValueError, f'more than {0x4000-1}'): - private._RecordDescriptor(('a',) * 20000) + with self.assertRaisesRegex(ValueError, f"more than {0x4000-1}"): + private._RecordDescriptor(("a",) * 20000) def test_recorddesc_2(self): rd = private._RecordDescriptor( - ('a', 'b', 'c'), - (private._EDGE_POINTER_IS_LINKPROP, - 0, - private._EDGE_POINTER_IS_LINK)) + ("a", "b", "c"), + ( + private._EDGE_POINTER_IS_LINKPROP, + 0, + private._EDGE_POINTER_IS_LINK, + ), + ) - self.assertEqual(rd.get_pos('a'), 0) - self.assertEqual(rd.get_pos('b'), 1) - self.assertEqual(rd.get_pos('c'), 2) + self.assertEqual(rd.get_pos("a"), 0) + self.assertEqual(rd.get_pos("b"), 1) + self.assertEqual(rd.get_pos("c"), 2) - self.assertTrue(rd.is_linkprop('a')) - self.assertFalse(rd.is_linkprop('b')) - self.assertFalse(rd.is_linkprop('c')) + self.assertTrue(rd.is_linkprop("a")) + self.assertFalse(rd.is_linkprop("b")) + self.assertFalse(rd.is_linkprop("c")) - self.assertFalse(rd.is_link('a')) - self.assertFalse(rd.is_link('b')) - self.assertTrue(rd.is_link('c')) + self.assertFalse(rd.is_link("a")) + self.assertFalse(rd.is_link("b")) + self.assertTrue(rd.is_link("c")) with self.assertRaises(LookupError): - rd.get_pos('z') + rd.get_pos("z") with self.assertRaises(LookupError): - rd.is_linkprop('z') + rd.is_linkprop("z") def test_recorddesc_3(self): f = private.create_object_factory( - id={'property', 'implicit'}, - lb='link-property', - c='property', - d='link', + id={"property", "implicit"}, + lb="link-property", + c="property", + d="link", ) o = f(1, 2, 3, 4) desc = private.get_object_descriptor(o) - self.assertEqual(set(dir(desc)), set(('id', '@lb', 'c', 'd'))) + self.assertEqual(set(dir(desc)), set(("id", "@lb", "c", "d"))) - self.assertTrue(desc.is_linkprop('@lb')) - self.assertFalse(desc.is_linkprop('id')) - self.assertFalse(desc.is_linkprop('c')) - self.assertFalse(desc.is_linkprop('d')) + self.assertTrue(desc.is_linkprop("@lb")) + self.assertFalse(desc.is_linkprop("id")) + self.assertFalse(desc.is_linkprop("c")) + self.assertFalse(desc.is_linkprop("d")) - self.assertFalse(desc.is_link('@lb')) - self.assertFalse(desc.is_link('id')) - self.assertFalse(desc.is_link('c')) - self.assertTrue(desc.is_link('d')) + self.assertFalse(desc.is_link("@lb")) + self.assertFalse(desc.is_link("id")) + self.assertFalse(desc.is_link("c")) + self.assertTrue(desc.is_link("d")) - self.assertFalse(desc.is_implicit('@lb')) - self.assertTrue(desc.is_implicit('id')) - self.assertFalse(desc.is_implicit('c')) - self.assertFalse(desc.is_implicit('d')) + self.assertFalse(desc.is_implicit("@lb")) + self.assertTrue(desc.is_implicit("id")) + self.assertFalse(desc.is_implicit("c")) + self.assertFalse(desc.is_implicit("d")) - self.assertEqual(desc.get_pos('@lb'), 1) - self.assertEqual(desc.get_pos('id'), 0) - self.assertEqual(desc.get_pos('c'), 2) - self.assertEqual(desc.get_pos('d'), 3) + self.assertEqual(desc.get_pos("@lb"), 1) + self.assertEqual(desc.get_pos("id"), 0) + self.assertEqual(desc.get_pos("c"), 2) + self.assertEqual(desc.get_pos("d"), 3) def test_recorddesc_4(self): f = private.create_object_factory( - id={'property', 'implicit'}, - lb='link-property', - c='property', - d='link', + id={"property", "implicit"}, + lb="link-property", + c="property", + d="link", ) o = f(1, 2, 3, 4) @@ -135,10 +139,10 @@ def test_recorddesc_4(self): self.assertEqual( intro.pointers, ( - ('id', introspect.ElementKind.PROPERTY, True), - ('c', introspect.ElementKind.PROPERTY, False), - ('d', introspect.ElementKind.LINK, False), - ) + ("id", introspect.ElementKind.PROPERTY, True), + ("c", introspect.ElementKind.PROPERTY, False), + ("d", introspect.ElementKind.LINK, False), + ), ) # clear cache so that tests in refcount mode don't freak out. @@ -146,26 +150,25 @@ def test_recorddesc_4(self): class TestTuple(unittest.TestCase): - def test_tuple_empty_1(self): t = edgedb.Tuple() self.assertIsInstance(t, tuple) self.assertEqual(len(t), 0) self.assertEqual(hash(t), hash(())) - self.assertEqual(repr(t), '()') - with self.assertRaisesRegex(IndexError, 'out of range'): + self.assertEqual(repr(t), "()") + with self.assertRaisesRegex(IndexError, "out of range"): t[0] def test_tuple_2(self): - t = edgedb.Tuple((1, 'a')) + t = edgedb.Tuple((1, "a")) self.assertEqual(len(t), 2) - self.assertEqual(hash(t), hash((1, 'a'))) + self.assertEqual(hash(t), hash((1, "a"))) self.assertEqual(repr(t), "(1, 'a')") self.assertEqual(t[0], 1) - self.assertEqual(t[1], 'a') - with self.assertRaisesRegex(IndexError, 'out of range'): + self.assertEqual(t[1], "a") + with self.assertRaisesRegex(IndexError, "out of range"): t[2] def test_tuple_3(self): @@ -173,8 +176,8 @@ def test_tuple_3(self): t[1].append(t) self.assertEqual(t[1], [t]) - self.assertEqual(repr(t), '(1, [(...)])') - self.assertEqual(str(t), '(1, [(...)])') + self.assertEqual(repr(t), "(1, [(...)])") + self.assertEqual(str(t), "(1, [(...)])") def test_tuple_freelist_1(self): lst = [] @@ -184,139 +187,86 @@ def test_tuple_freelist_1(self): self.assertEqual(t[0], 1) def test_tuple_5(self): - self.assertEqual( - edgedb.Tuple([1, 2, 3]), - edgedb.Tuple([1, 2, 3])) + self.assertEqual(edgedb.Tuple([1, 2, 3]), edgedb.Tuple([1, 2, 3])) - self.assertNotEqual( - edgedb.Tuple([1, 2, 3]), - edgedb.Tuple([1, 3, 2])) + self.assertNotEqual(edgedb.Tuple([1, 2, 3]), edgedb.Tuple([1, 3, 2])) - self.assertLess( - edgedb.Tuple([1, 2, 3]), - edgedb.Tuple([1, 3, 2])) + self.assertLess(edgedb.Tuple([1, 2, 3]), edgedb.Tuple([1, 3, 2])) - self.assertEqual( - edgedb.Tuple([]), - edgedb.Tuple([])) + self.assertEqual(edgedb.Tuple([]), edgedb.Tuple([])) - self.assertEqual( - edgedb.Tuple([1]), - edgedb.Tuple([1])) + self.assertEqual(edgedb.Tuple([1]), edgedb.Tuple([1])) - self.assertGreaterEqual( - edgedb.Tuple([1]), - edgedb.Tuple([1])) + self.assertGreaterEqual(edgedb.Tuple([1]), edgedb.Tuple([1])) - self.assertNotEqual( - edgedb.Tuple([1]), - edgedb.Tuple([])) + self.assertNotEqual(edgedb.Tuple([1]), edgedb.Tuple([])) - self.assertGreater( - edgedb.Tuple([1]), - edgedb.Tuple([])) + self.assertGreater(edgedb.Tuple([1]), edgedb.Tuple([])) - self.assertNotEqual( - edgedb.Tuple([1]), - edgedb.Tuple([2])) + self.assertNotEqual(edgedb.Tuple([1]), edgedb.Tuple([2])) - self.assertLess( - edgedb.Tuple([1]), - edgedb.Tuple([2])) + self.assertLess(edgedb.Tuple([1]), edgedb.Tuple([2])) - self.assertNotEqual( - edgedb.Tuple([1, 2]), - edgedb.Tuple([2, 2])) + self.assertNotEqual(edgedb.Tuple([1, 2]), edgedb.Tuple([2, 2])) - self.assertNotEqual( - edgedb.Tuple([1, 1]), - edgedb.Tuple([2, 2, 1])) + self.assertNotEqual(edgedb.Tuple([1, 1]), edgedb.Tuple([2, 2, 1])) def test_tuple_6(self): - self.assertEqual( - edgedb.Tuple([1, 2, 3]), - (1, 2, 3)) + self.assertEqual(edgedb.Tuple([1, 2, 3]), (1, 2, 3)) - self.assertEqual( - (1, 2, 3), - edgedb.Tuple([1, 2, 3])) + self.assertEqual((1, 2, 3), edgedb.Tuple([1, 2, 3])) - self.assertNotEqual( - edgedb.Tuple([1, 2, 3]), - (1, 3, 2)) + self.assertNotEqual(edgedb.Tuple([1, 2, 3]), (1, 3, 2)) - self.assertLess( - edgedb.Tuple([1, 2, 3]), - (1, 3, 2)) + self.assertLess(edgedb.Tuple([1, 2, 3]), (1, 3, 2)) - self.assertEqual( - edgedb.Tuple([]), - ()) + self.assertEqual(edgedb.Tuple([]), ()) - self.assertEqual( - edgedb.Tuple([1]), - (1,)) + self.assertEqual(edgedb.Tuple([1]), (1,)) - self.assertGreaterEqual( - edgedb.Tuple([1]), - (1,)) + self.assertGreaterEqual(edgedb.Tuple([1]), (1,)) - self.assertNotEqual( - edgedb.Tuple([1]), - ()) + self.assertNotEqual(edgedb.Tuple([1]), ()) - self.assertGreater( - edgedb.Tuple([1]), - ()) + self.assertGreater(edgedb.Tuple([1]), ()) - self.assertNotEqual( - edgedb.Tuple([1]), - (2,)) + self.assertNotEqual(edgedb.Tuple([1]), (2,)) - self.assertLess( - edgedb.Tuple([1]), - (2,)) + self.assertLess(edgedb.Tuple([1]), (2,)) - self.assertNotEqual( - edgedb.Tuple([1, 2]), - (2, 2)) + self.assertNotEqual(edgedb.Tuple([1, 2]), (2, 2)) - self.assertNotEqual( - edgedb.Tuple([1, 1]), - (2, 2, 1)) + self.assertNotEqual(edgedb.Tuple([1, 1]), (2, 2, 1)) def test_tuple_7(self): - self.assertNotEqual( - edgedb.Tuple([1, 2, 3]), - 123) + self.assertNotEqual(edgedb.Tuple([1, 2, 3]), 123) class TestNamedTuple(unittest.TestCase): - def test_namedtuple_empty_1(self): - with self.assertRaisesRegex(ValueError, 'at least one field'): + with self.assertRaisesRegex(ValueError, "at least one field"): edgedb.NamedTuple() def test_namedtuple_2(self): t = edgedb.NamedTuple(a=1) self.assertEqual(repr(t), "(a := 1)") - t = edgedb.NamedTuple(a=1, b='a') + t = edgedb.NamedTuple(a=1, b="a") - self.assertEqual(set(dir(t)), {'a', 'b'}) + self.assertEqual(set(dir(t)), {"a", "b"}) self.assertEqual(repr(t), "(a := 1, b := 'a')") self.assertEqual(t[0], 1) - self.assertEqual(t[1], 'a') - with self.assertRaisesRegex(IndexError, 'out of range'): + self.assertEqual(t[1], "a") + with self.assertRaisesRegex(IndexError, "out of range"): t[2] self.assertEqual(len(t), 2) - self.assertEqual(hash(t), hash((1, 'a'))) + self.assertEqual(hash(t), hash((1, "a"))) self.assertEqual(t.a, 1) - self.assertEqual(t.b, 'a') + self.assertEqual(t.b, "a") with self.assertRaises(AttributeError): t.z @@ -326,72 +276,52 @@ def test_namedtuple_3(self): t.b.append(t) self.assertEqual(t.b, [t]) - self.assertEqual(repr(t), '(a := 1, b := [(...)])') - self.assertEqual(str(t), '(a := 1, b := [(...)])') + self.assertEqual(repr(t), "(a := 1, b := [(...)])") + self.assertEqual(str(t), "(a := 1, b := [(...)])") def test_namedtuple_4(self): - t1 = edgedb.NamedTuple(a=1, b='aaaa') - t2 = edgedb.Tuple((1, 'aaaa')) - t3 = (1, 'aaaa') + t1 = edgedb.NamedTuple(a=1, b="aaaa") + t2 = edgedb.Tuple((1, "aaaa")) + t3 = (1, "aaaa") self.assertEqual(hash(t1), hash(t2)) self.assertEqual(hash(t1), hash(t3)) def test_namedtuple_5(self): self.assertEqual( - edgedb.NamedTuple(a=1, b=2, c=3), - edgedb.NamedTuple(x=1, y=2, z=3)) + edgedb.NamedTuple(a=1, b=2, c=3), edgedb.NamedTuple(x=1, y=2, z=3) + ) self.assertNotEqual( - edgedb.NamedTuple(a=1, b=2, c=3), - edgedb.NamedTuple(a=1, c=3, b=2)) + edgedb.NamedTuple(a=1, b=2, c=3), edgedb.NamedTuple(a=1, c=3, b=2) + ) self.assertLess( - edgedb.NamedTuple(a=1, b=2, c=3), - edgedb.NamedTuple(a=1, b=3, c=2)) + edgedb.NamedTuple(a=1, b=2, c=3), edgedb.NamedTuple(a=1, b=3, c=2) + ) - self.assertEqual( - edgedb.NamedTuple(a=1), - edgedb.NamedTuple(b=1)) + self.assertEqual(edgedb.NamedTuple(a=1), edgedb.NamedTuple(b=1)) - self.assertEqual( - edgedb.NamedTuple(a=1), - edgedb.NamedTuple(a=1)) + self.assertEqual(edgedb.NamedTuple(a=1), edgedb.NamedTuple(a=1)) def test_namedtuple_6(self): - self.assertEqual( - edgedb.NamedTuple(a=1, b=2, c=3), - (1, 2, 3)) + self.assertEqual(edgedb.NamedTuple(a=1, b=2, c=3), (1, 2, 3)) - self.assertEqual( - (1, 2, 3), - edgedb.NamedTuple(a=1, b=2, c=3)) + self.assertEqual((1, 2, 3), edgedb.NamedTuple(a=1, b=2, c=3)) - self.assertNotEqual( - edgedb.NamedTuple(a=1, b=2, c=3), - (1, 3, 2)) + self.assertNotEqual(edgedb.NamedTuple(a=1, b=2, c=3), (1, 3, 2)) - self.assertLess( - edgedb.NamedTuple(a=1, b=2, c=3), - (1, 3, 2)) + self.assertLess(edgedb.NamedTuple(a=1, b=2, c=3), (1, 3, 2)) - self.assertEqual( - edgedb.NamedTuple(a=1), - (1,)) + self.assertEqual(edgedb.NamedTuple(a=1), (1,)) - self.assertEqual( - edgedb.NamedTuple(a=1), - (1,)) + self.assertEqual(edgedb.NamedTuple(a=1), (1,)) def test_namedtuple_7(self): - self.assertNotEqual( - edgedb.NamedTuple(a=1, b=2, c=3), - 1) + self.assertNotEqual(edgedb.NamedTuple(a=1, b=2, c=3), 1) def test_namedtuple_8(self): - self.assertEqual( - edgedb.NamedTuple(壹=1, 贰=2, 叄=3), - (1, 2, 3)) + self.assertEqual(edgedb.NamedTuple(壹=1, 贰=2, 叄=3), (1, 2, 3)) def test_namedtuple_memory(self): num = int(os.getenv("EDGEDB_PYTHON_TEST_NAMEDTUPLE_MEMORY", 100)) @@ -411,7 +341,7 @@ def test(): fix_tp(random.randint(10, 20), random.randint(20, 30)) ) if len(nt) % random.randint(10, 20) == 0: - nt[:] = nt[random.randint(5, len(nt)):] + nt[:] = nt[random.randint(5, len(nt)) :] gc.collect() gc.collect() @@ -485,17 +415,14 @@ def test_derived_namedtuple_4(self): class TestObject(unittest.TestCase): - def test_object_1(self): f = private.create_object_factory( - id='property', - lb='link-property', - c='property' + id="property", lb="link-property", c="property" ) o = f(1, 2, 3) - self.assertEqual(repr(o), 'Object{id := 1, @lb := 2, c := 3}') + self.assertEqual(repr(o), "Object{id := 1, @lb := 2, c := 3}") self.assertEqual(o.id, 1) self.assertEqual(o.c, 3) @@ -513,101 +440,91 @@ def test_object_1(self): o[0] with self.assertRaises(TypeError): - o['id'] + o["id"] - self.assertEqual(set(dir(o)), {'id', 'c'}) + self.assertEqual(set(dir(o)), {"id", "c"}) def test_object_2(self): f = private.create_object_factory( - id={'property', 'implicit'}, - lb='link-property', - c='property' + id={"property", "implicit"}, lb="link-property", c="property" ) o = f(1, 2, 3) - self.assertEqual(repr(o), 'Object{@lb := 2, c := 3}') + self.assertEqual(repr(o), "Object{@lb := 2, c := 3}") self.assertNotEqual(hash(o), hash(f(1, 2, 3))) - self.assertNotEqual(hash(o), hash(f(1, 2, 'aaaa'))) + self.assertNotEqual(hash(o), hash(f(1, 2, "aaaa"))) self.assertNotEqual(hash(o), hash((1, 2, 3))) - self.assertEqual(set(dir(o)), {'id', 'c'}) + self.assertEqual(set(dir(o)), {"id", "c"}) def test_object_3(self): - f = private.create_object_factory(id='property', c='link') + f = private.create_object_factory(id="property", c="link") o = f(1, []) o.c.append(o) - self.assertEqual(repr(o), 'Object{id := 1, c := [Object{...}]}') + self.assertEqual(repr(o), "Object{id := 1, c := [Object{...}]}") def test_object_4(self): f = private.create_object_factory( - id={'property', 'implicit'}, - lb='link-property', - c='property' + id={"property", "implicit"}, lb="link-property", c="property" ) - o1 = f(1, 'aa', 'ba') - o2 = f(1, 'ab', 'bb') - o3 = f(3, 'ac', 'bc') + o1 = f(1, "aa", "ba") + o2 = f(1, "ab", "bb") + o3 = f(3, "ac", "bc") self.assertNotEqual(o1, o2) self.assertNotEqual(o1, o3) def test_object_5(self): f = private.create_object_factory( - a='property', - lb='link-property', - c='property' + a="property", lb="link-property", c="property" ) with self.assertRaisesRegex(ValueError, "without 'id' field"): f(1, 2, 3) def test_object_6(self): User = private.create_object_factory( - id='property', - name='property', + id="property", + name="property", ) - u = User(1, 'user1') + u = User(1, "user1") - with self.assertRaisesRegex(TypeError, - "property 'name' should be " - "accessed via dot notation"): - u['name'] + with self.assertRaisesRegex( + TypeError, "property 'name' should be " "accessed via dot notation" + ): + u["name"] @test_deprecated def test_object_links_1(self): O2 = private.create_object_factory( - id='property', - lb='link-property', - c='property' + id="property", lb="link-property", c="property" ) - O1 = private.create_object_factory( - id='property', - o2s='link' - ) + O1 = private.create_object_factory(id="property", o2s="link") - o2_1 = O2(1, 'linkprop o2 1', 3) - o2_2 = O2(4, 'linkprop o2 2', 6) + o2_1 = O2(1, "linkprop o2 1", 3) + o2_2 = O2(4, "linkprop o2 2", 6) o1 = O1(2, edgedb.Set((o2_1, o2_2))) - linkset = o1['o2s'] + linkset = o1["o2s"] self.assertEqual(len(linkset), 2) - self.assertEqual(linkset, o1['o2s']) + self.assertEqual(linkset, o1["o2s"]) self.assertEqual( repr(linkset), - "LinkSet(name='o2s', source_id=2, target_ids={1, 4})") + "LinkSet(name='o2s', source_id=2, target_ids={1, 4})", + ) link1 = linkset[0] self.assertIs(link1.source, o1) self.assertIs(link1.target, o2_1) self.assertEqual( - repr(link1), - "Link(name='o2s', source_id=2, target_id=1)") - self.assertEqual(set(dir(link1)), {'target', 'source', 'lb'}) + repr(link1), "Link(name='o2s', source_id=2, target_id=1)" + ) + self.assertEqual(set(dir(link1)), {"target", "source", "lb"}) link2 = linkset[1] self.assertIs(link2.source, o1) @@ -620,8 +537,8 @@ def test_object_links_1(self): self.assertNotEqual(link1, link2) - self.assertEqual(link1.lb, 'linkprop o2 1') - self.assertEqual(link2.lb, 'linkprop o2 2') + self.assertEqual(link1.lb, "linkprop o2 1") + self.assertEqual(link2.lb, "linkprop o2 2") with self.assertRaises(AttributeError): link2.aaaa @@ -629,9 +546,9 @@ def test_object_links_1(self): @test_deprecated def test_object_links_2(self): User = private.create_object_factory( - id='property', - friends='link', - enemies='link', + id="property", + friends="link", + enemies="link", ) u1 = User(1, edgedb.Set([]), edgedb.Set([])) @@ -640,42 +557,43 @@ def test_object_links_2(self): u4 = User(4, edgedb.Set([u1, u2]), edgedb.Set([u1, u2])) u5 = User(5, edgedb.Set([u1, u3]), edgedb.Set([u1, u2])) - self.assertNotEqual(u4['friends'], u4['enemies']) - self.assertNotEqual(u4['enemies'], u5['enemies']) + self.assertNotEqual(u4["friends"], u4["enemies"]) + self.assertNotEqual(u4["enemies"], u5["enemies"]) - self.assertEqual(set(dir(u1)), {'id', 'friends', 'enemies'}) + self.assertEqual(set(dir(u1)), {"id", "friends", "enemies"}) @test_deprecated def test_object_links_3(self): User = private.create_object_factory( - id='property', - friend='link', + id="property", + friend="link", ) u1 = User(1, None) u2 = User(2, u1) u3 = User(3, edgedb.Set([])) - self.assertEqual(set(dir(u2['friend'])), {'source', 'target'}) + self.assertEqual(set(dir(u2["friend"])), {"source", "target"}) - self.assertIs(u2['friend'].target, u1) + self.assertIs(u2["friend"].target, u1) - self.assertIsNone(u1['friend']) + self.assertIsNone(u1["friend"]) - self.assertEqual(len(u3['friend']), 0) + self.assertEqual(len(u3["friend"]), 0) self.assertEqual( - repr(u3['friend']), - "LinkSet(name='friend', source_id=3, target_ids={})") + repr(u3["friend"]), + "LinkSet(name='friend', source_id=3, target_ids={})", + ) self.assertEqual( - repr(u2['friend']), - "Link(name='friend', source_id=2, target_id=1)") + repr(u2["friend"]), "Link(name='friend', source_id=2, target_id=1)" + ) @test_deprecated def test_object_links_4(self): User = private.create_object_factory( - id='property', - friend='link', + id="property", + friend="link", ) u = User(1, None) @@ -683,22 +601,17 @@ def test_object_links_4(self): with self.assertRaisesRegex( KeyError, "link property '@error_key' does not exist" ): - u['@error_key'] + u["@error_key"] def test_object_link_property_1(self): O2 = private.create_object_factory( - id='property', - lb='link-property', - c='property' + id="property", lb="link-property", c="property" ) - O1 = private.create_object_factory( - id='property', - o2s='link' - ) + O1 = private.create_object_factory(id="property", o2s="link") - o2_1 = O2(1, 'linkprop o2 1', 3) - o2_2 = O2(4, 'linkprop o2 2', 6) + o2_1 = O2(1, "linkprop o2 1", 3) + o2_2 = O2(4, "linkprop o2 2", 6) o1 = O1(2, edgedb.Set((o2_1, o2_2))) o2s = o1.o2s @@ -707,13 +620,13 @@ def test_object_link_property_1(self): self.assertEqual( repr(o2s), "[Object{id := 1, @lb := 'linkprop o2 1', c := 3}," - " Object{id := 4, @lb := 'linkprop o2 2', c := 6}]" + " Object{id := 4, @lb := 'linkprop o2 2', c := 6}]", ) - self.assertEqual(o2s[0]['@lb'], 'linkprop o2 1') - self.assertEqual(o2s[1]['@lb'], 'linkprop o2 2') - self.assertEqual(getattr(o2s[0], '@lb'), 'linkprop o2 1') - self.assertEqual(getattr(o2s[1], '@lb'), 'linkprop o2 2') + self.assertEqual(o2s[0]["@lb"], "linkprop o2 1") + self.assertEqual(o2s[1]["@lb"], "linkprop o2 2") + self.assertEqual(getattr(o2s[0], "@lb"), "linkprop o2 1") + self.assertEqual(getattr(o2s[1], "@lb"), "linkprop o2 2") with self.assertRaises(AttributeError): o2s[0].lb @@ -725,31 +638,31 @@ def test_object_link_property_1(self): TypeError, "link property 'lb' should be accessed with '@' prefix", ): - o2s[0]['lb'] + o2s[0]["lb"] with self.assertRaisesRegex( TypeError, "property 'c' should be accessed via dot notation" ): - o2s[0]['c'] + o2s[0]["c"] with self.assertRaisesRegex( KeyError, "link property '@c' does not exist" ): - o2s[0]['@c'] + o2s[0]["@c"] def test_object_dataclass_1(self): User = private.create_object_factory( - id='property', - name='property', - tuple='property', - namedtuple='property', + id="property", + name="property", + tuple="property", + namedtuple="property", linkprop="link-property", ) u = User( 1, - 'Bob', - edgedb.Tuple((1, 2.0, '3')), + "Bob", + edgedb.Tuple((1, 2.0, "3")), edgedb.NamedTuple(a=1, b="Y"), 123, ) @@ -757,21 +670,20 @@ def test_object_dataclass_1(self): self.assertEqual( dataclasses.asdict(u), { - 'id': 1, - 'name': 'Bob', - 'tuple': (1, 2.0, '3'), - 'namedtuple': (1, "Y"), + "id": 1, + "name": "Bob", + "tuple": (1, 2.0, "3"), + "namedtuple": (1, "Y"), }, ) class TestSet(unittest.TestCase): - def test_set_1(self): s = edgedb.Set(()) - self.assertEqual(repr(s), '[]') + self.assertEqual(repr(s), "[]") - s = edgedb.Set((1, 2, [], 'a')) + s = edgedb.Set((1, 2, [], "a")) self.assertEqual(s[1], 2) self.assertEqual(s[2], []) @@ -780,7 +692,7 @@ def test_set_1(self): s[10] def test_set_2(self): - s = edgedb.Set((1, 2, 3000, 'a')) + s = edgedb.Set((1, 2, 3000, "a")) self.assertEqual(repr(s), "[1, 2, 3000, 'a']") @@ -795,69 +707,41 @@ def test_set_4(self): self.assertEqual(repr(s), "[[[...]]]") def test_set_5(self): - self.assertNotEqual( - edgedb.Set([1, 2, 3]), - edgedb.Set([3, 2, 1])) + self.assertNotEqual(edgedb.Set([1, 2, 3]), edgedb.Set([3, 2, 1])) - self.assertEqual( - edgedb.Set([]), - edgedb.Set([])) + self.assertEqual(edgedb.Set([]), edgedb.Set([])) - self.assertEqual( - edgedb.Set([1]), - edgedb.Set([1])) + self.assertEqual(edgedb.Set([1]), edgedb.Set([1])) - self.assertNotEqual( - edgedb.Set([1]), - edgedb.Set([])) + self.assertNotEqual(edgedb.Set([1]), edgedb.Set([])) - self.assertNotEqual( - edgedb.Set([1]), - edgedb.Set([2])) + self.assertNotEqual(edgedb.Set([1]), edgedb.Set([2])) - self.assertNotEqual( - edgedb.Set([1, 2]), - edgedb.Set([2, 2])) + self.assertNotEqual(edgedb.Set([1, 2]), edgedb.Set([2, 2])) - self.assertNotEqual( - edgedb.Set([1, 1, 2]), - edgedb.Set([2, 2, 1])) + self.assertNotEqual(edgedb.Set([1, 1, 2]), edgedb.Set([2, 2, 1])) def test_set_6(self): f = private.create_object_factory( - id={'property', 'implicit'}, - lb='link-property', - c='property' + id={"property", "implicit"}, lb="link-property", c="property" ) - o1 = f(1, 'aa', edgedb.Set([1, 2, 3])) - o2 = f(1, 'ab', edgedb.Set([1, 2, 4])) - o3 = f(3, 'ac', edgedb.Set([5, 5, 5, 5])) + o1 = f(1, "aa", edgedb.Set([1, 2, 3])) + o2 = f(1, "ab", edgedb.Set([1, 2, 4])) + o3 = f(3, "ac", edgedb.Set([5, 5, 5, 5])) - self.assertNotEqual( - edgedb.Set([o1, o2, o3]), - edgedb.Set([o2, o3, o1])) + self.assertNotEqual(edgedb.Set([o1, o2, o3]), edgedb.Set([o2, o3, o1])) - self.assertNotEqual( - edgedb.Set([o1, o3]), - edgedb.Set([o2, o3])) + self.assertNotEqual(edgedb.Set([o1, o3]), edgedb.Set([o2, o3])) - self.assertNotEqual( - edgedb.Set([o1, o1]), - edgedb.Set([o2, o3])) + self.assertNotEqual(edgedb.Set([o1, o1]), edgedb.Set([o2, o3])) def test_set_7(self): - self.assertEqual( - edgedb.Set([1, 2, 3]), - [1, 2, 3]) + self.assertEqual(edgedb.Set([1, 2, 3]), [1, 2, 3]) - self.assertNotEqual( - edgedb.Set([1, 2, 3]), - [3, 2, 1]) + self.assertNotEqual(edgedb.Set([1, 2, 3]), [3, 2, 1]) - self.assertNotEqual( - edgedb.Set([1, 2, 3]), - 1) + self.assertNotEqual(edgedb.Set([1, 2, 3]), 1) def test_set_8(self): s = edgedb.Set([1, 2, 3]) @@ -866,16 +750,15 @@ def test_set_8(self): class TestArray(unittest.TestCase): - def test_array_empty_1(self): t = edgedb.Array() self.assertEqual(len(t), 0) - with self.assertRaisesRegex(IndexError, 'out of range'): + with self.assertRaisesRegex(IndexError, "out of range"): t[0] self.assertEqual(repr(t), "[]") def test_array_2(self): - t = edgedb.Array((1, 'a')) + t = edgedb.Array((1, "a")) self.assertEqual(repr(t), "[1, 'a']") self.assertEqual(str(t), "[1, 'a']") @@ -883,123 +766,69 @@ def test_array_2(self): self.assertEqual(len(t), 2) self.assertEqual(t[0], 1) - self.assertEqual(t[1], 'a') - with self.assertRaisesRegex(IndexError, 'out of range'): + self.assertEqual(t[1], "a") + with self.assertRaisesRegex(IndexError, "out of range"): t[2] def test_array_3(self): t = edgedb.Array((1, [])) t[1].append(t) self.assertEqual(t[1], [t]) - self.assertEqual(repr(t), '[1, [[...]]]') + self.assertEqual(repr(t), "[1, [[...]]]") def test_array_4(self): - self.assertEqual( - edgedb.Array([1, 2, 3]), - edgedb.Array([1, 2, 3])) + self.assertEqual(edgedb.Array([1, 2, 3]), edgedb.Array([1, 2, 3])) - self.assertNotEqual( - edgedb.Array([1, 2, 3]), - edgedb.Array([1, 3, 2])) + self.assertNotEqual(edgedb.Array([1, 2, 3]), edgedb.Array([1, 3, 2])) - self.assertLess( - edgedb.Array([1, 2, 3]), - edgedb.Array([1, 3, 2])) + self.assertLess(edgedb.Array([1, 2, 3]), edgedb.Array([1, 3, 2])) - self.assertEqual( - edgedb.Array([]), - edgedb.Array([])) + self.assertEqual(edgedb.Array([]), edgedb.Array([])) - self.assertEqual( - edgedb.Array([1]), - edgedb.Array([1])) + self.assertEqual(edgedb.Array([1]), edgedb.Array([1])) - self.assertGreaterEqual( - edgedb.Array([1]), - edgedb.Array([1])) + self.assertGreaterEqual(edgedb.Array([1]), edgedb.Array([1])) - self.assertNotEqual( - edgedb.Array([1]), - edgedb.Array([])) + self.assertNotEqual(edgedb.Array([1]), edgedb.Array([])) - self.assertGreater( - edgedb.Array([1]), - edgedb.Array([])) + self.assertGreater(edgedb.Array([1]), edgedb.Array([])) - self.assertNotEqual( - edgedb.Array([1]), - edgedb.Array([2])) + self.assertNotEqual(edgedb.Array([1]), edgedb.Array([2])) - self.assertLess( - edgedb.Array([1]), - edgedb.Array([2])) + self.assertLess(edgedb.Array([1]), edgedb.Array([2])) - self.assertNotEqual( - edgedb.Array([1, 2]), - edgedb.Array([2, 2])) + self.assertNotEqual(edgedb.Array([1, 2]), edgedb.Array([2, 2])) - self.assertNotEqual( - edgedb.Array([1, 1]), - edgedb.Array([2, 2, 1])) + self.assertNotEqual(edgedb.Array([1, 1]), edgedb.Array([2, 2, 1])) def test_array_5(self): - self.assertEqual( - edgedb.Array([1, 2, 3]), - [1, 2, 3]) + self.assertEqual(edgedb.Array([1, 2, 3]), [1, 2, 3]) - self.assertEqual( - [1, 2, 3], - edgedb.Array([1, 2, 3])) + self.assertEqual([1, 2, 3], edgedb.Array([1, 2, 3])) - self.assertNotEqual( - [1, 2, 4], - edgedb.Array([1, 2, 3])) + self.assertNotEqual([1, 2, 4], edgedb.Array([1, 2, 3])) - self.assertNotEqual( - edgedb.Array([1, 2, 3]), - [1, 3, 2]) + self.assertNotEqual(edgedb.Array([1, 2, 3]), [1, 3, 2]) - self.assertLess( - edgedb.Array([1, 2, 3]), - [1, 3, 2]) + self.assertLess(edgedb.Array([1, 2, 3]), [1, 3, 2]) - self.assertEqual( - edgedb.Array([]), - []) + self.assertEqual(edgedb.Array([]), []) - self.assertEqual( - edgedb.Array([1]), - [1]) + self.assertEqual(edgedb.Array([1]), [1]) - self.assertGreaterEqual( - edgedb.Array([1]), - [1]) + self.assertGreaterEqual(edgedb.Array([1]), [1]) - self.assertNotEqual( - edgedb.Array([1]), - []) + self.assertNotEqual(edgedb.Array([1]), []) - self.assertGreater( - edgedb.Array([1]), - []) + self.assertGreater(edgedb.Array([1]), []) - self.assertNotEqual( - edgedb.Array([1]), - [2]) + self.assertNotEqual(edgedb.Array([1]), [2]) - self.assertLess( - edgedb.Array([1]), - [2]) + self.assertLess(edgedb.Array([1]), [2]) - self.assertNotEqual( - edgedb.Array([1, 2]), - [2, 2]) + self.assertNotEqual(edgedb.Array([1, 2]), [2, 2]) - self.assertNotEqual( - edgedb.Array([1, 1]), - [2, 2, 1]) + self.assertNotEqual(edgedb.Array([1, 1]), [2, 2, 1]) def test_array_6(self): - self.assertNotEqual( - edgedb.Array([1, 2, 3]), - False) + self.assertNotEqual(edgedb.Array([1, 2, 3]), False) diff --git a/tests/datatypes/test_uuid.py b/tests/datatypes/test_uuid.py index ca85812f..84cbf307 100644 --- a/tests/datatypes/test_uuid.py +++ b/tests/datatypes/test_uuid.py @@ -26,24 +26,24 @@ from edgedb.protocol.protocol import UUID as c_UUID -special_uuids = frozenset({ - std_UUID('00000000-0000-0000-0000-000000000000'), - std_UUID('00000000-0000-0000-0000-000000000001'), - std_UUID('10000000-0000-0000-0000-000000000000'), - std_UUID('10000000-0000-0000-0000-000000000001'), - std_UUID('FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF'), - std_UUID('0F0F0F0F-0F0F-0F0F-0F0F-0F0F0F0F0F0F'), - std_UUID('F0F0F0F0-F0F0-F0F0-F0F0-F0F0F0F0F0F0'), -}) +special_uuids = frozenset( + { + std_UUID("00000000-0000-0000-0000-000000000000"), + std_UUID("00000000-0000-0000-0000-000000000001"), + std_UUID("10000000-0000-0000-0000-000000000000"), + std_UUID("10000000-0000-0000-0000-000000000001"), + std_UUID("FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF"), + std_UUID("0F0F0F0F-0F0F-0F0F-0F0F-0F0F0F0F0F0F"), + std_UUID("F0F0F0F0-F0F0-F0F0-F0F0-F0F0F0F0F0F0"), + } +) test_uuids = tuple( - special_uuids | - frozenset({uuid.uuid4() for _ in range(1000)}) + special_uuids | frozenset({uuid.uuid4() for _ in range(1000)}) ) class TestUuid(unittest.TestCase): - def ensure_equal(self, uuid1, uuid2): self.assertEqual(uuid1.bytes_le, uuid2.bytes_le) self.assertEqual(uuid1.clock_seq, uuid2.clock_seq) @@ -65,33 +65,36 @@ def ensure_equal(self, uuid1, uuid2): self.assertLessEqual(uuid2, uuid1) def test_uuid_ctr_01(self): - with self.assertRaisesRegex(ValueError, r'invalid UUID.*got 4'): - c_UUID('test') - - with self.assertRaisesRegex(ValueError, - r'invalid UUID.*decodes to less'): - c_UUID('49e3b4e4-4761-11e9-9160-2f38d067497') - - for v in {'49e3b4e4476111e991602f38d067497aaaaa', - '49e3b4e4476111e991602f38d067497aaaa', - '49e3b4e4476111e991602f38d067497aaa', - '49e3b4e4476111e991602f38d067497aa', - '49e3b4e4476111e-991602f-38d067497aa'}: - with self.assertRaisesRegex(ValueError, - r'invalid UUID.*decodes to more'): + with self.assertRaisesRegex(ValueError, r"invalid UUID.*got 4"): + c_UUID("test") + + with self.assertRaisesRegex( + ValueError, r"invalid UUID.*decodes to less" + ): + c_UUID("49e3b4e4-4761-11e9-9160-2f38d067497") + + for v in { + "49e3b4e4476111e991602f38d067497aaaaa", + "49e3b4e4476111e991602f38d067497aaaa", + "49e3b4e4476111e991602f38d067497aaa", + "49e3b4e4476111e991602f38d067497aa", + "49e3b4e4476111e-991602f-38d067497aa", + }: + with self.assertRaisesRegex( + ValueError, r"invalid UUID.*decodes to more" + ): print(c_UUID(v)) - with self.assertRaisesRegex(ValueError, - r"invalid UUID.*unexpected.*'x'"): - c_UUID('49e3b4e4-4761-11e9-9160-2f38dx67497a') + with self.assertRaisesRegex( + ValueError, r"invalid UUID.*unexpected.*'x'" + ): + c_UUID("49e3b4e4-4761-11e9-9160-2f38dx67497a") - with self.assertRaisesRegex(ValueError, - r"invalid UUID.*unexpected"): - c_UUID('49e3b4e4-4761-11160-2f😱3867497a') + with self.assertRaisesRegex(ValueError, r"invalid UUID.*unexpected"): + c_UUID("49e3b4e4-4761-11160-2f😱3867497a") - with self.assertRaisesRegex(ValueError, - r"invalid UUID.*unexpected"): - c_UUID('49e3b4e4-4761-11eE-\xAA60-2f38dx67497a') + with self.assertRaisesRegex(ValueError, r"invalid UUID.*unexpected"): + c_UUID("49e3b4e4-4761-11eE-\xAA60-2f38dx67497a") def test_uuid_ctr_02(self): for py_u in test_uuids: @@ -103,7 +106,7 @@ def test_uuid_ctr_02(self): self.ensure_equal(py_u, c_u) def test_uuid_pickle(self): - u = c_UUID('de197476-4763-11e9-91bf-7311c6dc588e') + u = c_UUID("de197476-4763-11e9-91bf-7311c6dc588e") d = pickle.dumps(u) u2 = pickle.loads(d) self.assertEqual(u, u2) @@ -111,12 +114,12 @@ def test_uuid_pickle(self): self.assertEqual(u.bytes, u2.bytes) def test_uuid_instance(self): - u = c_UUID('de197476-4763-11e9-91bf-7311c6dc588e') + u = c_UUID("de197476-4763-11e9-91bf-7311c6dc588e") self.assertTrue(isinstance(u, uuid.UUID)) self.assertTrue(issubclass(c_UUID, uuid.UUID)) def test_uuid_compare(self): - u = c_UUID('de197476-4763-11e9-91bf-7311c6dc588e') + u = c_UUID("de197476-4763-11e9-91bf-7311c6dc588e") us = uuid.UUID(bytes=u.bytes) for us2 in test_uuids: @@ -129,14 +132,14 @@ def test_uuid_compare(self): self.assertGreater(u, u2) self.assertLess(u2, u) - u3 = c_UUID('de197476-4763-11e9-91bf-7311c6dc588e') + u3 = c_UUID("de197476-4763-11e9-91bf-7311c6dc588e") self.assertTrue(u == u3) self.assertFalse(u != u3) self.assertTrue(u >= u3) self.assertTrue(u <= u3) - a = c_UUID('10000000-0000-0000-0000-000000000001') - b = c_UUID('10000000-0000-0000-0000-000000000000') + a = c_UUID("10000000-0000-0000-0000-000000000001") + b = c_UUID("10000000-0000-0000-0000-000000000000") self.assertGreater(a, b) self.assertLess(b, a) diff --git a/tests/test_async_query.py b/tests/test_async_query.py index 36f3f121..ceca85bd 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -40,18 +40,17 @@ class TestAsyncQuery(tb.AsyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE TYPE test::Tmp { CREATE REQUIRED PROPERTY tmp -> std::str; }; CREATE SCALAR TYPE MyEnum EXTENDING enum<"A", "B">; - ''' + """ - TEARDOWN = ''' + TEARDOWN = """ DROP TYPE test::Tmp; - ''' + """ def setUp(self): super().setUp() @@ -60,68 +59,70 @@ def setUp(self): async def test_async_parse_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.client.query('select syntax error') + await self.client.query("select syntax error") with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.client.query('select syntax error') + await self.client.query("select syntax error") - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): - await self.client.query('select (') + with self.assertRaisesRegex( + edgedb.EdgeQLSyntaxError, "Unexpected end of line" + ): + await self.client.query("select (") - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): - await self.client.query_json('select (') + with self.assertRaisesRegex( + edgedb.EdgeQLSyntaxError, "Unexpected end of line" + ): + await self.client.query_json("select (") for _ in range(10): self.assertEqual( - await self.client.query('select 1;'), - edgedb.Set((1,))) + await self.client.query("select 1;"), edgedb.Set((1,)) + ) self.assertFalse(self.client.connection.is_closed()) async def test_async_parse_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.client.execute('select syntax error') + await self.client.execute("select syntax error") with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.client.execute('select syntax error') + await self.client.execute("select syntax error") for _ in range(10): - await self.client.execute('select 1; select 2;'), + await self.client.execute("select 1; select 2;"), async def test_async_exec_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - await self.client.query('select 1 / 0;') + await self.client.query("select 1 / 0;") with self.assertRaises(edgedb.DivisionByZeroError): - await self.client.query('select 1 / 0;') + await self.client.query("select 1 / 0;") for _ in range(10): self.assertEqual( - await self.client.query('select 1;'), - edgedb.Set((1,))) + await self.client.query("select 1;"), edgedb.Set((1,)) + ) async def test_async_exec_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - await self.client.execute('select 1 / 0;') + await self.client.execute("select 1 / 0;") with self.assertRaises(edgedb.DivisionByZeroError): - await self.client.execute('select 1 / 0;') + await self.client.execute("select 1 / 0;") for _ in range(10): - await self.client.execute('select 1;') + await self.client.execute("select 1;") async def test_async_exec_error_recover_03(self): - query = 'select 10 // $0;' + query = "select 10 // $0;" for i in [1, 2, 0, 3, 1, 0, 1]: if i: self.assertEqual( - await self.client.query(query, i), - edgedb.Set([10 // i])) + await self.client.query(query, i), edgedb.Set([10 // i]) + ) else: with self.assertRaises(edgedb.DivisionByZeroError): await self.client.query(query, i) @@ -129,17 +130,15 @@ async def test_async_exec_error_recover_03(self): async def test_async_exec_error_recover_04(self): for i in [1, 2, 0, 3, 1, 0, 1]: if i: - await self.client.execute(f'select 10 // {i};') + await self.client.execute(f"select 10 // {i};") else: with self.assertRaises(edgedb.DivisionByZeroError): - await self.client.query(f'select 10 // {i};') + await self.client.query(f"select 10 // {i};") async def test_async_exec_error_recover_05(self): with self.assertRaises(edgedb.DivisionByZeroError): - await self.client.execute(f'select 1 / 0') - self.assertEqual( - await self.client.query('SELECT "HELLO"'), - ["HELLO"]) + await self.client.execute("select 1 / 0") + self.assertEqual(await self.client.query('SELECT "HELLO"'), ["HELLO"]) async def test_async_query_single_01(self): res = await self.client.query_single("SELECT 1") @@ -153,188 +152,209 @@ async def test_async_query_single_01(self): await self.client.query_required_single("SELECT {}") async def test_async_query_single_command_01(self): - r = await self.client.query(''' + r = await self.client.query( + """ CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; }; - ''') + """ + ) self.assertEqual(r, []) - r = await self.client.query(''' + r = await self.client.query( + """ DROP TYPE test::server_query_single_command_01; - ''') + """ + ) self.assertEqual(r, []) - r = await self.client.query(''' + r = await self.client.query( + """ CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; }; - ''') + """ + ) self.assertEqual(r, []) - r = await self.client.query_json(''' + r = await self.client.query_json( + """ DROP TYPE test::server_query_single_command_01; - ''') - self.assertEqual(r, '[]') + """ + ) + self.assertEqual(r, "[]") - r = await self.client.query_json(''' + r = await self.client.query_json( + """ CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; }; - ''') - self.assertEqual(r, '[]') + """ + ) + self.assertEqual(r, "[]") with self.assertRaisesRegex( - edgedb.InterfaceError, - r'query cannot be executed with query_required_single_json\('): - await self.client.query_required_single_json(''' + edgedb.InterfaceError, + r"query cannot be executed with query_required_single_json\(", + ): + await self.client.query_required_single_json( + """ DROP TYPE test::server_query_single_command_01; - ''') + """ + ) - r = await self.client.query_json(''' + r = await self.client.query_json( + """ DROP TYPE test::server_query_single_command_01; - ''') - self.assertEqual(r, '[]') + """ + ) + self.assertEqual(r, "[]") self.assertTrue( - self.client.connection._get_last_status().startswith('DROP') + self.client.connection._get_last_status().startswith("DROP") ) async def test_async_query_no_return(self): with self.assertRaisesRegex( - edgedb.InterfaceError, - r'cannot be executed with query_required_single\(\).*' - r'not return'): - await self.client.query_required_single('create type Foo456') + edgedb.InterfaceError, + r"cannot be executed with query_required_single\(\).*" + r"not return", + ): + await self.client.query_required_single("create type Foo456") with self.assertRaisesRegex( - edgedb.InterfaceError, - r'cannot be executed with query_required_single_json\(\).*' - r'not return'): - await self.client.query_required_single_json('create type Bar456') + edgedb.InterfaceError, + r"cannot be executed with query_required_single_json\(\).*" + r"not return", + ): + await self.client.query_required_single_json("create type Bar456") async def test_async_basic_datatypes_01(self): for _ in range(10): - self.assertEqual( - await self.client.query_single( - 'select ()'), - ()) + self.assertEqual(await self.client.query_single("select ()"), ()) self.assertEqual( - await self.client.query( - 'select (1,)'), - edgedb.Set([(1,)])) + await self.client.query("select (1,)"), edgedb.Set([(1,)]) + ) self.assertEqual( - await self.client.query( - 'select ["a", "b"]'), - edgedb.Set([["a", "b"]])) + await self.client.query('select ["a", "b"]'), + edgedb.Set([["a", "b"]]), + ) self.assertEqual( - await self.client.query(''' + await self.client.query( + """ SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; - '''), - edgedb.Set([ - edgedb.NamedTuple(a=42, world=("hello", 32)), - edgedb.NamedTuple(a=1, world=("yo", 10)), - ])) + """ + ), + edgedb.Set( + [ + edgedb.NamedTuple(a=42, world=("hello", 32)), + edgedb.NamedTuple(a=1, world=("yo", 10)), + ] + ), + ) with self.assertRaisesRegex( - edgedb.InterfaceError, - r'query_single\(\) as it may return more than one element' + edgedb.InterfaceError, + r"query_single\(\) as it may return more than one element", ): - await self.client.query_single('SELECT {1, 2}') + await self.client.query_single("SELECT {1, 2}") with self.assertRaisesRegex( - edgedb.InterfaceError, - r'query_required_single\(\) as it may return ' - r'more than one element'): - await self.client.query_required_single('SELECT {1, 2}') + edgedb.InterfaceError, + r"query_required_single\(\) as it may return " + r"more than one element", + ): + await self.client.query_required_single("SELECT {1, 2}") with self.assertRaisesRegex( - edgedb.NoDataError, - r'\bquery_required_single\('): - await self.client.query_required_single('SELECT {}') + edgedb.NoDataError, r"\bquery_required_single\(" + ): + await self.client.query_required_single("SELECT {}") async def test_async_basic_datatypes_02(self): self.assertEqual( await self.client.query( - r'''select [b"\x00a", b"b", b'', b'\na']'''), - edgedb.Set([[b"\x00a", b"b", b'', b'\na']])) + r"""select [b"\x00a", b"b", b'', b'\na']""" + ), + edgedb.Set([[b"\x00a", b"b", b"", b"\na"]]), + ) self.assertEqual( - await self.client.query( - r'select $0', b'he\x00llo'), - edgedb.Set([b'he\x00llo'])) + await self.client.query(r"select $0", b"he\x00llo"), + edgedb.Set([b"he\x00llo"]), + ) async def test_async_basic_datatypes_03(self): for _ in range(10): # test opportunistic execute - self.assertEqual( - await self.client.query_json( - 'select ()'), - '[[]]') + self.assertEqual(await self.client.query_json("select ()"), "[[]]") self.assertEqual( - await self.client.query_json( - 'select (1,)'), - '[[1]]') + await self.client.query_json("select (1,)"), "[[1]]" + ) self.assertEqual( - await self.client.query_json( - 'select >[]'), - '[[]]') + await self.client.query_json("select >[]"), "[[]]" + ) self.assertEqual( - json.loads( - await self.client.query_json( - 'select ["a", "b"]')), - [["a", "b"]]) + json.loads(await self.client.query_json('select ["a", "b"]')), + [["a", "b"]], + ) self.assertEqual( json.loads( - await self.client.query_single_json( - 'select ["a", "b"]')), - ["a", "b"]) + await self.client.query_single_json('select ["a", "b"]') + ), + ["a", "b"], + ) self.assertEqual( json.loads( - await self.client.query_json(''' + await self.client.query_json( + """ SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; - ''')), + """ + ) + ), [ {"a": 42, "world": ["hello", 32]}, - {"a": 1, "world": ["yo", 10]} - ]) + {"a": 1, "world": ["yo", 10]}, + ], + ) self.assertEqual( - json.loads( - await self.client.query_json('SELECT {1, 2}')), - [1, 2]) + json.loads(await self.client.query_json("SELECT {1, 2}")), + [1, 2], + ) self.assertEqual( - json.loads(await self.client.query_json('SELECT {}')), - []) + json.loads(await self.client.query_json("SELECT {}")), + [], + ) with self.assertRaises(edgedb.NoDataError): await self.client.query_required_single_json( - 'SELECT {}' + "SELECT {}" ) self.assertEqual( json.loads( - await self.client.query_single_json('SELECT {}') + await self.client.query_single_json("SELECT {}") ), - None + None, ) async def test_async_basic_datatypes_04(self): val = await self.client.query_single( - ''' + """ SELECT schema::ObjectType { foo := { [(a := 1, b := 2), (a := 3, b := 4)], @@ -342,53 +362,65 @@ async def test_async_basic_datatypes_04(self): >>[], } } LIMIT 1 - ''' + """ ) self.assertEqual( val.foo, - edgedb.Set([ - edgedb.Array([ - edgedb.NamedTuple(a=1, b=2), - edgedb.NamedTuple(a=3, b=4), - ]), - edgedb.Array([ - edgedb.NamedTuple(a=5, b=6), - ]), - edgedb.Array([]), - ]), + edgedb.Set( + [ + edgedb.Array( + [ + edgedb.NamedTuple(a=1, b=2), + edgedb.NamedTuple(a=3, b=4), + ] + ), + edgedb.Array( + [ + edgedb.NamedTuple(a=5, b=6), + ] + ), + edgedb.Array([]), + ] + ), ) async def test_async_args_01(self): self.assertEqual( await self.client.query( - 'select (>$foo)[0] ++ (>$bar)[0];', - foo=['aaa'], bar=['bbb']), - edgedb.Set(('aaabbb',))) + "select (>$foo)[0] ++ (>$bar)[0];", + foo=["aaa"], + bar=["bbb"], + ), + edgedb.Set(("aaabbb",)), + ) async def test_async_args_02(self): self.assertEqual( await self.client.query( - 'select (>$0)[0] ++ (>$1)[0];', - ['aaa'], ['bbb']), - edgedb.Set(('aaabbb',))) + "select (>$0)[0] ++ (>$1)[0];", + ["aaa"], + ["bbb"], + ), + edgedb.Set(("aaabbb",)), + ) async def test_async_args_03(self): - with self.assertRaisesRegex(edgedb.QueryError, r'missing \$0'): - await self.client.query('select $1;') + with self.assertRaisesRegex(edgedb.QueryError, r"missing \$0"): + await self.client.query("select $1;") - with self.assertRaisesRegex(edgedb.QueryError, r'missing \$1'): - await self.client.query('select $0 + $2;') + with self.assertRaisesRegex(edgedb.QueryError, r"missing \$1"): + await self.client.query("select $0 + $2;") - with self.assertRaisesRegex(edgedb.QueryError, - 'combine positional and named parameters'): - await self.client.query('select $0 + $bar;') + with self.assertRaisesRegex( + edgedb.QueryError, "combine positional and named parameters" + ): + await self.client.query("select $0 + $bar;") - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - "None is not allowed"): - await self.client.query( - "select >$0", [1, None, 3] - ) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "None is not allowed" + ): + await self.client.query("select >$0", [1, None, 3]) async def test_async_args_04(self): aware_datetime = datetime.datetime.now(datetime.timezone.utc) @@ -400,57 +432,66 @@ async def test_async_args_04(self): self.assertEqual( await self.client.query_single( - 'select $0;', - aware_datetime), - aware_datetime) + "select $0;", aware_datetime + ), + aware_datetime, + ) self.assertEqual( await self.client.query_single( - 'select $0;', - naive_datetime), - naive_datetime) + "select $0;", naive_datetime + ), + naive_datetime, + ) self.assertEqual( await self.client.query_single( - 'select $0;', - date), - date) + "select $0;", date + ), + date, + ) self.assertEqual( await self.client.query_single( - 'select $0;', - naive_time), - naive_time) + "select $0;", naive_time + ), + naive_time, + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'a timezone-aware.*expected'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, r"a timezone-aware.*expected" + ): await self.client.query_single( - 'select $0;', - naive_datetime) + "select $0;", naive_datetime + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'a naive time object.*expected'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, r"a naive time object.*expected" + ): await self.client.query_single( - 'select $0;', - aware_time) + "select $0;", aware_time + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'a naive datetime object.*expected'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, r"a naive datetime object.*expected" + ): await self.client.query_single( - 'select $0;', - aware_datetime) + "select $0;", aware_datetime + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'datetime.datetime object was expected'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, + r"datetime.datetime object was expected", + ): await self.client.query_single( - 'select $0;', - date) + "select $0;", date + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'datetime.datetime object was expected'): - await self.client.query_single( - 'select $0;', - date) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, + r"datetime.datetime object was expected", + ): + await self.client.query_single("select $0;", date) async def _test_async_args_05(self): # XXX move to edgedb/edgedb # Argument's cardinality must affect the input type ID hash. @@ -458,13 +499,9 @@ async def _test_async_args_05(self): # XXX move to edgedb/edgedb # codec would be cached and then used for the second query, # which would make it fail. + self.assertEqual(await self.client.query("select $a", a=1), [1]) self.assertEqual( - await self.client.query('select $a', a=1), - [1] - ) - self.assertEqual( - await self.client.query('select $a', a=None), - [] + await self.client.query("select $a", a=None), [] ) async def _test_async_args_06(self): # XXX move to edgedb/edgedb @@ -473,71 +510,73 @@ async def _test_async_args_06(self): # XXX move to edgedb/edgedb # client side too. self.assertEqual( - await self.client.query('select $a', a=1), - [1] + await self.client.query("select $a", a=1), [1] ) with self.assertRaisesRegex( - edgedb.InvalidArgumentError, - r'argument \$a is required, but received None'): + edgedb.InvalidArgumentError, + r"argument \$a is required, but received None", + ): self.assertEqual( - await self.client.query('select $a', a=None), - [] + await self.client.query("select $a", a=None), [] ) async def test_async_mismatched_args_01(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'a'} arguments, " - "got {'[bc]', '[bc]'}, " - r"missed {'a'}, extra {'[bc]', '[bc]'}"): - + edgedb.QueryArgumentError, + r"expected {'a'} arguments, " + "got {'[bc]', '[bc]'}, " + r"missed {'a'}, extra {'[bc]', '[bc]'}", + ): await self.client.query("""SELECT $a;""", b=1, c=2) async def test_async_mismatched_args_02(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'[ab]', '[ab]'} arguments, " - r"got {'[acd]', '[acd]', '[acd]'}, " - r"missed {'b'}, extra {'[cd]', '[cd]'}"): - - await self.client.query(""" + edgedb.QueryArgumentError, + r"expected {'[ab]', '[ab]'} arguments, " + r"got {'[acd]', '[acd]', '[acd]'}, " + r"missed {'b'}, extra {'[cd]', '[cd]'}", + ): + await self.client.query( + """ SELECT $a + $b; - """, a=1, c=2, d=3) + """, + a=1, + c=2, + d=3, + ) async def test_async_mismatched_args_03(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - "expected {'a'} arguments, got {'b'}, " - "missed {'a'}, extra {'b'}"): - + edgedb.QueryArgumentError, + "expected {'a'} arguments, got {'b'}, " + "missed {'a'}, extra {'b'}", + ): await self.client.query("""SELECT $a;""", b=1) async def test_async_mismatched_args_04(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'[ab]', '[ab]'} arguments, " - r"got {'a'}, " - r"missed {'b'}"): - + edgedb.QueryArgumentError, + r"expected {'[ab]', '[ab]'} arguments, " + r"got {'a'}, " + r"missed {'b'}", + ): await self.client.query("""SELECT $a + $b;""", a=1) async def test_async_mismatched_args_05(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'a'} arguments, " - r"got {'[ab]', '[ab]'}, " - r"extra {'b'}"): - + edgedb.QueryArgumentError, + r"expected {'a'} arguments, " + r"got {'[ab]', '[ab]'}, " + r"extra {'b'}", + ): await self.client.query("""SELECT $a;""", a=1, b=2) async def test_async_mismatched_args_06(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'a'} arguments, " - r"got nothing, " - r"missed {'a'}"): - + edgedb.QueryArgumentError, + r"expected {'a'} arguments, " r"got nothing, " r"missed {'a'}", + ): await self.client.query("""SELECT $a;""") async def test_async_mismatched_args_07(self): @@ -545,40 +584,43 @@ async def test_async_mismatched_args_07(self): edgedb.QueryArgumentError, "expected no named arguments", ): - await self.client.query("""SELECT 42""", a=1, b=2) async def test_async_args_uuid_pack(self): obj = await self.client.query_single( - 'select schema::Object {id, name} limit 1') + "select schema::Object {id, name} limit 1" + ) # Test that the custom UUID that our driver uses can be # passed back as a parameter. ot = await self.client.query_single( - 'select schema::Object {name} filter .id=$id', - id=obj.id) + "select schema::Object {name} filter .id=$id", id=obj.id + ) self.assertEqual(obj.id, ot.id) self.assertEqual(obj.name, ot.name) # Test that a string UUID is acceptable. ot = await self.client.query_single( - 'select schema::Object {name} filter .id=$id', - id=str(obj.id)) + "select schema::Object {name} filter .id=$id", id=str(obj.id) + ) self.assertEqual(obj.id, ot.id) self.assertEqual(obj.name, ot.name) # Test that a standard uuid.UUID is acceptable. ot = await self.client.query_single( - 'select schema::Object {name} filter .id=$id', - id=uuid.UUID(bytes=obj.id.bytes)) + "select schema::Object {name} filter .id=$id", + id=uuid.UUID(bytes=obj.id.bytes), + ) self.assertEqual(obj.id, ot.id) self.assertEqual(obj.name, ot.name) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'invalid UUID.*length must be'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "invalid UUID.*length must be" + ): await self.client.query( - 'select schema::Object {name} filter .id=$id', - id='asdasas') + "select schema::Object {name} filter .id=$id", + id="asdasas", + ) async def test_async_args_bigint_basic(self): testar = [ @@ -638,97 +680,94 @@ async def test_async_args_bigint_basic(self): ] for _ in range(500): - num = '' + num = "" for _ in range(random.randint(1, 50)): num += random.choice("0123456789") testar.append(int(num)) for _ in range(500): - num = '' + num = "" for _ in range(random.randint(1, 50)): num += random.choice("0000000012") testar.append(int(num)) val = await self.client.query_single( - 'select >$arg', - arg=testar) + "select >$arg", arg=testar + ) self.assertEqual(testar, val) async def test_async_args_bigint_pack(self): - val = await self.client.query_single( - 'select $arg', - arg=10) + val = await self.client.query_single("select $arg", arg=10) self.assertEqual(val, 10) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - await self.client.query( - 'select $arg', - arg='bad int') + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + await self.client.query("select $arg", arg="bad int") - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - await self.client.query( - 'select $arg', - arg=10.11) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + await self.client.query("select $arg", arg=10.11) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): await self.client.query( - 'select $arg', - arg=decimal.Decimal('10.0')) + "select $arg", arg=decimal.Decimal("10.0") + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): await self.client.query( - 'select $arg', - arg=decimal.Decimal('10.11')) + "select $arg", arg=decimal.Decimal("10.11") + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - await self.client.query( - 'select $arg', - arg='10') + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + await self.client.query("select $arg", arg="10") - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): await self.client.query_single( - 'select $arg', - arg=decimal.Decimal('10')) + "select $arg", arg=decimal.Decimal("10") + ) + + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): class IntLike: def __int__(self): return 10 await self.client.query_single( - 'select $arg', - arg=IntLike()) + "select $arg", arg=IntLike() + ) async def test_async_args_intlike(self): class IntLike: def __int__(self): return 10 - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - await self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + await self.client.query_single("select $arg", arg=IntLike()) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - await self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + await self.client.query_single("select $arg", arg=IntLike()) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - await self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + await self.client.query_single("select $arg", arg=IntLike()) async def test_async_args_decimal(self): class IntLike: @@ -736,42 +775,52 @@ def __int__(self): return 10 val = await self.client.query_single( - 'select $0', decimal.Decimal("10.0") + "select $0", decimal.Decimal("10.0") ) self.assertEqual(val, 10) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected a Decimal or an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected a Decimal or an int" + ): await self.client.query_single( - 'select $arg', - arg=IntLike()) + "select $arg", arg=IntLike() + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected a Decimal or an int'): - await self.client.query_single( - 'select $arg', - arg="10.2") + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected a Decimal or an int" + ): + await self.client.query_single("select $arg", arg="10.2") async def test_async_range_01(self): has_range = await self.client.query( - "select schema::ObjectType filter .name = 'schema::Range'") + "select schema::ObjectType filter .name = 'schema::Range'" + ) if not has_range: raise unittest.SkipTest("server has no support for std::range") samples = [ - ('range', [ - edgedb.Range(1, 2, inc_lower=True, inc_upper=False), - dict( - input=edgedb.Range(1, 2, inc_lower=True, inc_upper=True), - output=edgedb.Range(1, 3, inc_lower=True, inc_upper=False), - ), - edgedb.Range(empty=True), - dict( - input=edgedb.Range(1, 1, inc_lower=True, inc_upper=False), - output=edgedb.Range(empty=True), - ), - edgedb.Range(lower=None, upper=None), - ]), + ( + "range", + [ + edgedb.Range(1, 2, inc_lower=True, inc_upper=False), + dict( + input=edgedb.Range( + 1, 2, inc_lower=True, inc_upper=True + ), + output=edgedb.Range( + 1, 3, inc_lower=True, inc_upper=False + ), + ), + edgedb.Range(empty=True), + dict( + input=edgedb.Range( + 1, 1, inc_lower=True, inc_upper=False + ), + output=edgedb.Range(empty=True), + ), + edgedb.Range(lower=None, upper=None), + ], + ), ] for typename, sample_data in samples: @@ -779,8 +828,8 @@ async def test_async_range_01(self): with self.subTest(sample=sample, typname=typename): stmt = f"SELECT <{typename}>$0" if isinstance(sample, dict): - inputval = sample['input'] - outputval = sample['output'] + inputval = sample["input"] + outputval = sample["output"] else: inputval = outputval = sample @@ -788,28 +837,32 @@ async def test_async_range_01(self): err_msg = ( "unexpected result for {} when passing {!r}: " "received {!r}, expected {!r}".format( - typename, inputval, result, outputval)) + typename, inputval, result, outputval + ) + ) self.assertEqual(result, outputval, err_msg) async def test_async_range_02(self): has_range = await self.client.query( - "select schema::ObjectType filter .name = 'schema::Range'") + "select schema::ObjectType filter .name = 'schema::Range'" + ) if not has_range: raise unittest.SkipTest("server has no support for std::range") result = await self.client.query_single( - "SELECT >>$0", - [edgedb.Range(1, 2)] + "SELECT >>$0", [edgedb.Range(1, 2)] ) self.assertEqual([edgedb.Range(1, 2)], result) async def test_async_wait_cancel_01(self): - underscored_lock = await self.client.query_single(""" + underscored_lock = await self.client.query_single( + """ SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_advisory_lock' ) - """) + """ + ) if not underscored_lock: self.skipTest("No sys::_advisory_lock function") @@ -820,20 +873,19 @@ async def test_async_wait_cancel_01(self): client = self.client.with_retry_options(RetryOptions(attempts=1)) client2 = self.make_test_client( database=self.client.dbname - ).with_retry_options( - RetryOptions(attempts=1) - ) + ).with_retry_options(RetryOptions(attempts=1)) await client2.ensure_connected() async for tx in client.transaction(): async with tx: - self.assertTrue(await tx.query_single( - 'select sys::_advisory_lock($0)', - lock_key)) + self.assertTrue( + await tx.query_single( + "select sys::_advisory_lock($0)", lock_key + ) + ) try: async with TaskGroup() as g: - fut = asyncio.Future() async def exec_to_fail(): @@ -843,12 +895,12 @@ async def exec_to_fail(): async for tx2 in client2.transaction(): async with tx2: # start the lazy transaction - await tx2.query('SELECT 42;') + await tx2.query("SELECT 42;") fut.set_result(None) await tx2.query( - 'select sys::_advisory_lock(' + - '$0)', + "select sys::_advisory_lock(" + + "$0)", lock_key, ) @@ -871,12 +923,14 @@ async def exec_to_fail(): finally: self.assertEqual( await tx.query( - 'select sys::_advisory_unlock($0)', - lock_key), - [True]) + "select sys::_advisory_unlock($0)", lock_key + ), + [True], + ) async def test_empty_set_unpack(self): - await self.client.query_single(''' + await self.client.query_single( + """ select schema::Function { name, params: { @@ -886,35 +940,39 @@ async def test_empty_set_unpack(self): } filter .name = 'std::str_repeat' limit 1 - ''') + """ + ) async def test_enum_argument_01(self): - A = await self.client.query_single('SELECT $0', 'A') - self.assertEqual(str(A), 'A') + A = await self.client.query_single("SELECT $0", "A") + self.assertEqual(str(A), "A") with self.assertRaisesRegex( - edgedb.InvalidValueError, 'invalid input value for enum'): + edgedb.InvalidValueError, "invalid input value for enum" + ): async for tx in self.client.transaction(): async with tx: - await tx.query_single('SELECT $0', 'Oups') + await tx.query_single("SELECT $0", "Oups") self.assertEqual( - await self.client.query_single('SELECT $0', 'A'), - A) + await self.client.query_single("SELECT $0", "A"), A + ) self.assertEqual( - await self.client.query_single('SELECT $0', A), - A) + await self.client.query_single("SELECT $0", A), A + ) with self.assertRaisesRegex( - edgedb.InvalidValueError, 'invalid input value for enum'): + edgedb.InvalidValueError, "invalid input value for enum" + ): async for tx in self.client.transaction(): async with tx: - await tx.query_single('SELECT $0', 'Oups') + await tx.query_single("SELECT $0", "Oups") with self.assertRaisesRegex( - edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue'): - await self.client.query_single('SELECT $0', 123) + edgedb.InvalidArgumentError, "a str or edgedb.EnumValue" + ): + await self.client.query_single("SELECT $0", 123) async def test_enum_argument_02(self): class MyEnum(enum.Enum): @@ -922,8 +980,8 @@ class MyEnum(enum.Enum): B = "B" C = "C" - A = await self.client.query_single('SELECT $0', MyEnum.A) - self.assertEqual(str(A), 'A') + A = await self.client.query_single("SELECT $0", MyEnum.A) + self.assertEqual(str(A), "A") self.assertEqual(A, MyEnum.A) self.assertEqual(MyEnum.A, A) self.assertLess(A, MyEnum.B) @@ -939,19 +997,18 @@ class MyEnum(enum.Enum): with self.assertRaises(ValueError): _ = A == MyEnum.C with self.assertRaises(edgedb.InvalidArgumentError): - await self.client.query_single('SELECT $0', MyEnum.C) + await self.client.query_single("SELECT $0", MyEnum.C) async def test_json(self): self.assertEqual( await self.client.query_json('SELECT {"aaa", "bbb"}'), - '["aaa", "bbb"]') + '["aaa", "bbb"]', + ) async def test_json_elements(self): result = await self.client.connection.raw_query( abstract.QueryContext( - query=abstract.QueryWithArgs( - 'SELECT {"aaa", "bbb"}', (), {} - ), + query=abstract.QueryWithArgs('SELECT {"aaa", "bbb"}', (), {}), cache=self.client._get_query_cache(), query_options=abstract.QueryOptions( output_format=protocol.OutputFormat.JSON_ELEMENTS, @@ -962,32 +1019,32 @@ async def test_json_elements(self): state=None, ) ) - self.assertEqual( - result, - edgedb.Set(['"aaa"', '"bbb"'])) + self.assertEqual(result, edgedb.Set(['"aaa"', '"bbb"'])) async def test_async_cancel_01(self): - has_sleep = await self.client.query_single(""" + has_sleep = await self.client.query_single( + """ SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_sleep' ) - """) + """ + ) if not has_sleep: self.skipTest("No sys::_sleep function") client = self.make_test_client(database=self.client.dbname) try: - self.assertEqual(await client.query_single('SELECT 1'), 1) + self.assertEqual(await client.query_single("SELECT 1"), 1) protocol_before = client._impl._holders[0]._con._protocol with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( - client.query_single('SELECT sys::_sleep(10)'), - timeout=0.1) + client.query_single("SELECT sys::_sleep(10)"), timeout=0.1 + ) - await client.query('SELECT 2') + await client.query("SELECT 2") protocol_after = client._impl._holders[0]._con._protocol self.assertIsNot( @@ -1005,31 +1062,37 @@ def on_log(con, msg): self.client.connection.add_log_listener(on_log) try: await self.client.query( - 'configure system set __internal_restart := true;') + "configure system set __internal_restart := true;" + ) await asyncio.sleep(0.01) # allow the loop to call the callback finally: self.client.connection.remove_log_listener(on_log) for msg in msgs: - if (msg.get_severity_name() == 'NOTICE' and - 'server restart is required' in str(msg)): + if ( + msg.get_severity_name() == "NOTICE" + and "server restart is required" in str(msg) + ): break else: - raise AssertionError('a notice message was not delivered') + raise AssertionError("a notice message was not delivered") async def test_async_banned_transaction(self): with self.assertRaisesRegex( - edgedb.CapabilityError, - r'cannot execute transaction control commands'): - await self.client.query('start transaction') + edgedb.CapabilityError, + r"cannot execute transaction control commands", + ): + await self.client.query("start transaction") with self.assertRaisesRegex( - edgedb.CapabilityError, - r'cannot execute transaction control commands'): - await self.client.execute('start transaction') + edgedb.CapabilityError, + r"cannot execute transaction control commands", + ): + await self.client.execute("start transaction") async def test_dup_link_prop_name(self): - obj = await self.client.query_single(''' + obj = await self.client.query_single( + """ CREATE TYPE test::dup_link_prop_name { CREATE PROPERTY val -> str; }; @@ -1050,27 +1113,36 @@ async def test_dup_link_prop_name(self): @val } } LIMIT 1; - ''') + """ + ) self.assertEqual(obj.l.val, "hello") self.assertEqual(obj.l["@val"], 42) - await self.client.execute(''' + await self.client.execute( + """ DROP TYPE test::dup_link_prop_name_p; DROP TYPE test::dup_link_prop_name; - ''') + """ + ) async def test_transaction_state(self): with self.assertRaisesRegex(edgedb.QueryError, "cannot assign to id"): async for tx in self.client.transaction(): async with tx: - await tx.execute(''' + await tx.execute( + """ INSERT test::Tmp { id := $0, tmp := '' } - ''', uuid.uuid4()) + """, + uuid.uuid4(), + ) client = self.client.with_config(allow_user_specified_id=True) async for tx in client.transaction(): async with tx: - await tx.execute(''' + await tx.execute( + """ INSERT test::Tmp { id := $0, tmp := '' } - ''', uuid.uuid4()) + """, + uuid.uuid4(), + ) diff --git a/tests/test_async_retry.py b/tests/test_async_retry.py index d47c5603..7b64c8e4 100644 --- a/tests/test_async_retry.py +++ b/tests/test_async_retry.py @@ -47,8 +47,7 @@ async def ready(self): class TestAsyncRetry(tb.AsyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE TYPE test::Counter EXTENDING std::Object { CREATE PROPERTY name -> std::str { CREATE CONSTRAINT std::exclusive; @@ -57,36 +56,42 @@ class TestAsyncRetry(tb.AsyncQueryTestCase): SET default := 0; }; }; - ''' + """ - TEARDOWN = ''' + TEARDOWN = """ DROP TYPE test::Counter; - ''' + """ async def test_async_retry_01(self): async for tx in self.client.transaction(): async with tx: - await tx.execute(''' + await tx.execute( + """ INSERT test::Counter { name := 'counter1' }; - ''') + """ + ) async def test_async_retry_02(self): with self.assertRaises(ZeroDivisionError): async for tx in self.client.transaction(): async with tx: - await tx.execute(''' + await tx.execute( + """ INSERT test::Counter { name := 'counter_retry_02' }; - ''') + """ + ) 1 / 0 with self.assertRaises(edgedb.NoDataError): - await self.client.query_required_single(''' + await self.client.query_required_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_02' - ''') + """ + ) async def test_async_retry_begin(self): patcher = unittest.mock.patch( @@ -107,16 +112,20 @@ def cleanup(): with self.assertRaises(errors.BackendUnavailableError): async for tx in self.client.transaction(): async with tx: - await tx.execute(''' + await tx.execute( + """ INSERT test::Counter { name := 'counter_retry_begin' }; - ''') + """ + ) with self.assertRaises(edgedb.NoDataError): - await self.client.query_required_single(''' + await self.client.query_required_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_begin' - ''') + """ + ) async def recover_after_first_error(*_, **__): patcher.stop() @@ -127,28 +136,32 @@ async def recover_after_first_error(*_, **__): async for tx in self.client.transaction(): async with tx: - await tx.execute(''' + await tx.execute( + """ INSERT test::Counter { name := 'counter_retry_begin' }; - ''') + """ + ) self.assertEqual(_start.call_count, call_count + 1) - await self.client.query_single(''' + await self.client.query_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_begin' - ''') + """ + ) async def test_async_retry_conflict(self): - await self.execute_conflict('counter2') + await self.execute_conflict("counter2") async def test_async_conflict_no_retry(self): with self.assertRaises(edgedb.TransactionSerializationError): await self.execute_conflict( - 'counter3', - RetryOptions(attempts=1, backoff=edgedb.default_backoff) + "counter3", + RetryOptions(attempts=1, backoff=edgedb.default_backoff), ) - async def execute_conflict(self, name='counter2', options=None): + async def execute_conflict(self, name="counter2", options=None): client2 = self.make_test_client(database=self.get_database_name()) self.addCleanup(client2.aclose) @@ -174,7 +187,8 @@ async def transaction1(client): await barrier.ready() await lock.acquire() - res = await tx.query_single(''' + res = await tx.query_single( + """ SELECT ( INSERT test::Counter { name := $name, @@ -185,7 +199,9 @@ async def transaction1(client): SET { value := .value + 1 } ) ).value - ''', name=name) + """, + name=name, + ) lock.release() return res @@ -194,11 +210,14 @@ async def transaction1(client): client = client.with_retry_options(options) client2 = client2.with_retry_options(options) - results = await asyncio.wait_for(asyncio.gather( - transaction1(client), - transaction1(client2), - return_exceptions=True, - ), 10) + results = await asyncio.wait_for( + asyncio.gather( + transaction1(client), + transaction1(client2), + return_exceptions=True, + ), + 10, + ) for e in results: if isinstance(e, BaseException): raise e @@ -230,8 +249,9 @@ async def test_async_transaction_interface_errors(self): async for tx in self.client.transaction(): await tx.start() - with self.assertRaisesRegex(edgedb.InterfaceError, - r'.*Use `async with transaction:`'): + with self.assertRaisesRegex( + edgedb.InterfaceError, r".*Use `async with transaction:`" + ): async for tx in self.client.transaction(): await tx.execute("SELECT 123") diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index 8ceeb239..cf3852fa 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -27,38 +27,42 @@ class TestAsyncTx(tb.AsyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE TYPE test::TransactionTest EXTENDING std::Object { CREATE PROPERTY name -> std::str; }; - ''' + """ - TEARDOWN = ''' + TEARDOWN = """ DROP TYPE test::TransactionTest; - ''' + """ async def test_async_transaction_regular_01(self): tr = self.client.with_retry_options( - RetryOptions(attempts=1)).transaction() + RetryOptions(attempts=1) + ).transaction() with self.assertRaises(ZeroDivisionError): async for with_tr in tr: async with with_tr: - await with_tr.execute(''' + await with_tr.execute( + """ INSERT test::TransactionTest { name := 'Test Transaction' }; - ''') + """ + ) 1 / 0 - result = await self.client.query(''' + result = await self.client.query( + """ SELECT test::TransactionTest FILTER test::TransactionTest.name = 'Test Transaction'; - ''') + """ + ) self.assertEqual(result, []) @@ -100,7 +104,7 @@ async def test_async_transaction_exclusive(self): with self.assertRaisesRegex( edgedb.InterfaceError, "concurrent queries within the same transaction " - "are not allowed" + "are not allowed", ): await asyncio.wait_for(f1, timeout=5) await asyncio.wait_for(f2, timeout=5) diff --git a/tests/test_asyncio_client.py b/tests/test_asyncio_client.py index 13e7c843..0e259abf 100644 --- a/tests/test_asyncio_client.py +++ b/tests/test_asyncio_client.py @@ -70,12 +70,14 @@ async def test_client_05(self): client = self.create_client(max_concurrency=10) async def worker(): - self.assertEqual(await client.query('SELECT 1'), [1]) - self.assertEqual(await client.query_single('SELECT 1'), 1) + self.assertEqual(await client.query("SELECT 1"), [1]) + self.assertEqual(await client.query_single("SELECT 1"), 1) self.assertEqual( - await client.query_json('SELECT 1'), '[1]') + await client.query_json("SELECT 1"), "[1]" + ) self.assertEqual( - await client.query_single_json('SELECT 1'), '1') + await client.query_single_json("SELECT 1"), "1" + ) tasks = [worker() for _ in range(n)] await asyncio.gather(*tasks) @@ -94,9 +96,11 @@ async def test_client_options(self): client = self.create_client(max_concurrency=1) client.with_transaction_options( - edgedb.TransactionOptions(readonly=True)) + edgedb.TransactionOptions(readonly=True) + ) client.with_retry_options( - edgedb.RetryOptions(attempts=1, backoff=edgedb.default_backoff)) + edgedb.RetryOptions(attempts=1, backoff=edgedb.default_backoff) + ) async for tx in client.transaction(): async with tx: self.assertEqual(await tx.query_single("SELECT 7*8"), 56) @@ -112,12 +116,13 @@ async def test_client_no_acquire_deadlock(self): async with self.create_client( max_concurrency=1, ) as client: - - has_sleep = await client.query_single(""" + has_sleep = await client.query_single( + """ SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_sleep' ) - """) + """ + ) if not has_sleep: self.skipTest("No sys::_sleep function") @@ -146,7 +151,6 @@ async def test(client): max_concurrency=10, connection_class=MyConnection, ) as client: - await asyncio.gather(*[test(client) for _ in range(N)]) self.assertEqual( @@ -180,7 +184,6 @@ async def test_execute(client): async def run(N, meth): async with self.create_client(max_concurrency=10) as client: - coros = [meth(client) for _ in range(N)] res = await asyncio.gather(*coros) self.assertEqual(res, [1] * N) @@ -415,7 +418,7 @@ async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter): async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): ur, uw = await asyncio.open_connection( - con_args['host'], con_args['port'] + con_args["host"], con_args["port"] ) done.clear() task = self.loop.create_task(proxy(r, uw)) @@ -429,12 +432,10 @@ async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): w.close() uw.close() - server = await asyncio.start_server( - cb, '127.0.0.1', 0 - ) + server = await asyncio.start_server(cb, "127.0.0.1", 0) port = server.sockets[0].getsockname()[1] client = self.create_client( - host='127.0.0.1', + host="127.0.0.1", port=port, max_concurrency=1, wait_until_available=5, diff --git a/tests/test_blocking_client.py b/tests/test_blocking_client.py index 396356b5..ef6ad688 100644 --- a/tests/test_blocking_client.py +++ b/tests/test_blocking_client.py @@ -77,10 +77,10 @@ def test_client_05(self): client = self.create_client(max_concurrency=10) def worker(): - self.assertEqual(client.query('SELECT 1'), [1]) - self.assertEqual(client.query_single('SELECT 1'), 1) - self.assertEqual(client.query_json('SELECT 1'), '[1]') - self.assertEqual(client.query_single_json('SELECT 1'), '1') + self.assertEqual(client.query("SELECT 1"), [1]) + self.assertEqual(client.query_single("SELECT 1"), 1) + self.assertEqual(client.query_json("SELECT 1"), "[1]") + self.assertEqual(client.query_single_json("SELECT 1"), "1") tasks = [threading.Thread(target=worker) for _ in range(n)] for task in tasks: @@ -102,9 +102,11 @@ def test_client_options(self): client = self.create_client(max_concurrency=1) client.with_transaction_options( - edgedb.TransactionOptions(readonly=True)) + edgedb.TransactionOptions(readonly=True) + ) client.with_retry_options( - edgedb.RetryOptions(attempts=1, backoff=edgedb.default_backoff)) + edgedb.RetryOptions(attempts=1, backoff=edgedb.default_backoff) + ) for tx in client.transaction(): with tx: self.assertEqual(tx.query_single("SELECT 7*8"), 56) @@ -120,12 +122,13 @@ def test_client_no_acquire_deadlock(self): with self.create_client( max_concurrency=1, ) as client: - - has_sleep = client.query_single(""" + has_sleep = client.query_single( + """ SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_sleep' ) - """) + """ + ) if not has_sleep: self.skipTest("No sys::_sleep function") @@ -159,7 +162,6 @@ def test(): max_concurrency=10, connection_class=MyConnection, ) as client: - tasks = [threading.Thread(target=test) for _ in range(N)] for task in tasks: task.start() @@ -424,7 +426,7 @@ async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter): async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): ur, uw = await asyncio.open_connection( - con_args['host'], con_args['port'] + con_args["host"], con_args["port"] ) done.clear() task = self.loop.create_task(proxy(r, uw)) @@ -438,12 +440,10 @@ async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): w.close() uw.close() - server = await asyncio.start_server( - cb, '127.0.0.1', 0 - ) + server = await asyncio.start_server(cb, "127.0.0.1", 0) port = server.sockets[0].getsockname()[1] client = self.create_client( - host='127.0.0.1', + host="127.0.0.1", port=port, max_concurrency=1, wait_until_available=5, diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 102203fa..302db84f 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -74,13 +74,13 @@ async def run(*args, extra_env=None): else: if p.returncode: raise subprocess.CalledProcessError( - p.returncode, args, output=await p.stdout.read(), + p.returncode, + args, + output=await p.stdout.read(), ) cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "edgedb-py") - await run( - cmd, extra_env={"EDGEDB_PYTHON_CODEGEN_PY_VER": "3.8.5"} - ) + await run(cmd, extra_env={"EDGEDB_PYTHON_CODEGEN_PY_VER": "3.8.5"}) await run( cmd, "--target", diff --git a/tests/test_con_utils.py b/tests/test_con_utils.py index 742391ba..f7ef8d97 100644 --- a/tests/test_con_utils.py +++ b/tests/test_con_utils.py @@ -32,50 +32,65 @@ class TestConUtils(unittest.TestCase): - error_mapping = { - 'credentials_file_not_found': ( - RuntimeError, 'cannot read credentials'), - 'project_not_initialised': ( + "credentials_file_not_found": ( + RuntimeError, + "cannot read credentials", + ), + "project_not_initialised": ( + errors.ClientConnectionError, + "Found `edgedb.toml` but the project is not initialized", + ), + "no_options_or_toml": ( errors.ClientConnectionError, - 'Found `edgedb.toml` but the project is not initialized'), - 'no_options_or_toml': ( + "no `edgedb.toml` found and no connection options specified", + ), + "invalid_credentials_file": (RuntimeError, "cannot read credentials"), + "invalid_dsn_or_instance_name": ( + ValueError, + "invalid DSN or instance name", + ), + "invalid_instance_name": (ValueError, "invalid instance name"), + "invalid_dsn": (ValueError, "invalid DSN"), + "unix_socket_unsupported": ( + ValueError, + "unix socket paths not supported", + ), + "invalid_host": (ValueError, "invalid host"), + "invalid_port": (ValueError, "invalid port"), + "invalid_user": (ValueError, "invalid user"), + "invalid_database": (ValueError, "invalid database"), + "multiple_compound_env": ( errors.ClientConnectionError, - 'no `edgedb.toml` found and no connection options specified'), - 'invalid_credentials_file': ( - RuntimeError, 'cannot read credentials'), - 'invalid_dsn_or_instance_name': ( - ValueError, 'invalid DSN or instance name'), - 'invalid_instance_name': ( - ValueError, 'invalid instance name'), - 'invalid_dsn': (ValueError, 'invalid DSN'), - 'unix_socket_unsupported': ( - ValueError, 'unix socket paths not supported'), - 'invalid_host': (ValueError, 'invalid host'), - 'invalid_port': (ValueError, 'invalid port'), - 'invalid_user': (ValueError, 'invalid user'), - 'invalid_database': (ValueError, 'invalid database'), - 'multiple_compound_env': ( + "Cannot have more than one of the following connection " + + "environment variables", + ), + "multiple_compound_opts": ( errors.ClientConnectionError, - 'Cannot have more than one of the following connection ' - + 'environment variables'), - 'multiple_compound_opts': ( + "Cannot have more than one of the following connection options", + ), + "exclusive_options": ( errors.ClientConnectionError, - 'Cannot have more than one of the following connection options'), - 'exclusive_options': ( + "are mutually exclusive", + ), + "env_not_found": ( + ValueError, + 'environment variable ".*" doesn\'t exist', + ), + "file_not_found": (FileNotFoundError, "No such file or directory"), + "invalid_tls_security": ( + ValueError, + "tls_security can only be one of `insecure`, " + "|tls_security must be set to strict", + ), + "invalid_secret_key": ( errors.ClientConnectionError, - 'are mutually exclusive'), - 'env_not_found': ( - ValueError, 'environment variable ".*" doesn\'t exist'), - 'file_not_found': (FileNotFoundError, 'No such file or directory'), - 'invalid_tls_security': ( - ValueError, 'tls_security can only be one of `insecure`, ' - '|tls_security must be set to strict'), - 'invalid_secret_key': ( - errors.ClientConnectionError, "Invalid secret key"), - 'secret_key_not_found': ( + "Invalid secret key", + ), + "secret_key_not_found": ( errors.ClientConnectionError, - "Cannot connect to cloud instances without secret key"), + "Cannot connect to cloud instances without secret key", + ), } @contextlib.contextmanager @@ -102,52 +117,60 @@ def environ(self, **kwargs): os.environ[key] = val def run_testcase(self, testcase): - env = testcase.get('env', {}) - test_env = {'EDGEDB_HOST': None, 'EDGEDB_PORT': None, - 'EDGEDB_USER': None, 'EDGEDB_PASSWORD': None, - 'EDGEDB_SECRET_KEY': None, - 'EDGEDB_DATABASE': None, 'PGSSLMODE': None, - 'XDG_CONFIG_HOME': None} + env = testcase.get("env", {}) + test_env = { + "EDGEDB_HOST": None, + "EDGEDB_PORT": None, + "EDGEDB_USER": None, + "EDGEDB_PASSWORD": None, + "EDGEDB_SECRET_KEY": None, + "EDGEDB_DATABASE": None, + "PGSSLMODE": None, + "XDG_CONFIG_HOME": None, + } test_env.update(env) - fs = testcase.get('fs') - - opts = testcase.get('opts', {}) - dsn = opts['instance'] if 'instance' in opts else opts.get('dsn') - credentials = opts.get('credentials') - credentials_file = opts.get('credentialsFile') - host = opts.get('host') - port = opts.get('port') - database = opts.get('database') - user = opts.get('user') - password = opts.get('password') - secret_key = opts.get('secretKey') - tls_ca = opts.get('tlsCA') - tls_ca_file = opts.get('tlsCAFile') - tls_security = opts.get('tlsSecurity') - server_settings = opts.get('serverSettings') - wait_until_available = opts.get('waitUntilAvailable') - - other_opts = testcase.get('other_opts', {}) - timeout = other_opts.get('timeout') - command_timeout = other_opts.get('command_timeout') - - expected = testcase.get('result') - expected_error = testcase.get('error') - if expected_error and expected_error.get('type'): - expected_error = self.error_mapping.get(expected_error.get('type')) + fs = testcase.get("fs") + + opts = testcase.get("opts", {}) + dsn = opts["instance"] if "instance" in opts else opts.get("dsn") + credentials = opts.get("credentials") + credentials_file = opts.get("credentialsFile") + host = opts.get("host") + port = opts.get("port") + database = opts.get("database") + user = opts.get("user") + password = opts.get("password") + secret_key = opts.get("secretKey") + tls_ca = opts.get("tlsCA") + tls_ca_file = opts.get("tlsCAFile") + tls_security = opts.get("tlsSecurity") + server_settings = opts.get("serverSettings") + wait_until_available = opts.get("waitUntilAvailable") + + other_opts = testcase.get("other_opts", {}) + timeout = other_opts.get("timeout") + command_timeout = other_opts.get("command_timeout") + + expected = testcase.get("result") + expected_error = testcase.get("error") + if expected_error and expected_error.get("type"): + expected_error = self.error_mapping.get(expected_error.get("type")) if not expected_error: raise RuntimeError( - f"unknown error type: {testcase.get('error').get('type')}") + f"unknown error type: {testcase.get('error').get('type')}" + ) if expected is None and expected_error is None: raise RuntimeError( 'invalid test case: either "result" or "error" key ' - 'has to be specified') + "has to be specified" + ) if expected is not None and expected_error is not None: raise RuntimeError( 'invalid test case: either "result" or "error" key ' - 'has to be specified, got both') + "has to be specified, got both" + ) result = None with contextlib.ExitStack() as es: @@ -156,56 +179,58 @@ def run_testcase(self, testcase): stat_result = os.stat(os.getcwd()) es.enter_context( - mock.patch('os.stat', lambda _, **__: stat_result) + mock.patch("os.stat", lambda _, **__: stat_result) ) if fs: - cwd = fs.get('cwd') - homedir = fs.get('homedir') - files = fs.get('files') + cwd = fs.get("cwd") + homedir = fs.get("homedir") + files = fs.get("files") if cwd: - es.enter_context(mock.patch('os.getcwd', lambda: cwd)) + es.enter_context(mock.patch("os.getcwd", lambda: cwd)) if homedir: homedir = pathlib.Path(homedir) es.enter_context( - mock.patch('pathlib.Path.home', lambda: homedir) + mock.patch("pathlib.Path.home", lambda: homedir) ) if files: for f, v in files.copy().items(): if "${HASH}" in f: - hash = con_utils._hash_path(v['project-path']) + hash = con_utils._hash_path(v["project-path"]) dir = f.replace("${HASH}", hash) files[dir] = "" - instance = os.path.join(dir, 'instance-name') - files[instance] = v['instance-name'] - project = os.path.join(dir, 'project-path') - files[project] = v['project-path'] - if 'cloud-profile' in v: - profile = os.path.join(dir, 'cloud-profile') - files[profile] = v['cloud-profile'] + instance = os.path.join(dir, "instance-name") + files[instance] = v["instance-name"] + project = os.path.join(dir, "project-path") + files[project] = v["project-path"] + if "cloud-profile" in v: + profile = os.path.join(dir, "cloud-profile") + files[profile] = v["cloud-profile"] del files[f] es.enter_context( mock.patch( - 'os.path.exists', - lambda filepath: str(filepath) in files + "os.path.exists", + lambda filepath: str(filepath) in files, ) ) es.enter_context( mock.patch( - 'os.path.isfile', - lambda filepath: str(filepath) in files + "os.path.isfile", + lambda filepath: str(filepath) in files, ) ) - es.enter_context(mock.patch( - 'os.stat', - lambda d, **_: mock.Mock(st_dev=0), - )) + es.enter_context( + mock.patch( + "os.stat", + lambda d, **_: mock.Mock(st_dev=0), + ) + ) es.enter_context( - mock.patch('os.path.realpath', lambda f: f) + mock.patch("os.path.realpath", lambda f: f) ) def mocked_open(filepath, *args, **kwargs): @@ -214,10 +239,11 @@ def mocked_open(filepath, *args, **kwargs): read_data=files.get(str(filepath)) )() raise FileNotFoundError( - f"[Errno 2] No such file or directory: " + - f"'{filepath}'" + "[Errno 2] No such file or directory: " + + f"'{filepath}'" ) - es.enter_context(mock.patch('builtins.open', mocked_open)) + + es.enter_context(mock.patch("builtins.open", mocked_open)) if expected_error: es.enter_context(self.assertRaisesRegex(*expected_error)) @@ -242,70 +268,69 @@ def mocked_open(filepath, *args, **kwargs): ) result = { - 'address': [ - connect_config.address[0], connect_config.address[1] + "address": [ + connect_config.address[0], + connect_config.address[1], ], - 'database': connect_config.database, - 'user': connect_config.user, - 'password': connect_config.password, - 'secretKey': connect_config.secret_key, - 'tlsCAData': connect_config._tls_ca_data, - 'tlsSecurity': connect_config.tls_security, - 'serverSettings': connect_config.server_settings, - 'waitUntilAvailable': client_config.wait_until_available, + "database": connect_config.database, + "user": connect_config.user, + "password": connect_config.password, + "secretKey": connect_config.secret_key, + "tlsCAData": connect_config._tls_ca_data, + "tlsSecurity": connect_config.tls_security, + "serverSettings": connect_config.server_settings, + "waitUntilAvailable": client_config.wait_until_available, } if expected is not None: self.assertEqual(expected, result) def test_test_connect_params_environ(self): - self.assertNotIn('AAAAAAAAAA123', os.environ) - self.assertNotIn('AAAAAAAAAA456', os.environ) - self.assertNotIn('AAAAAAAAAA789', os.environ) + self.assertNotIn("AAAAAAAAAA123", os.environ) + self.assertNotIn("AAAAAAAAAA456", os.environ) + self.assertNotIn("AAAAAAAAAA789", os.environ) try: + os.environ["AAAAAAAAAA456"] = "123" + os.environ["AAAAAAAAAA789"] = "123" - os.environ['AAAAAAAAAA456'] = '123' - os.environ['AAAAAAAAAA789'] = '123' + with self.environ( + AAAAAAAAAA123="1", AAAAAAAAAA456="2", AAAAAAAAAA789=None + ): + self.assertEqual(os.environ["AAAAAAAAAA123"], "1") + self.assertEqual(os.environ["AAAAAAAAAA456"], "2") + self.assertNotIn("AAAAAAAAAA789", os.environ) - with self.environ(AAAAAAAAAA123='1', - AAAAAAAAAA456='2', - AAAAAAAAAA789=None): - - self.assertEqual(os.environ['AAAAAAAAAA123'], '1') - self.assertEqual(os.environ['AAAAAAAAAA456'], '2') - self.assertNotIn('AAAAAAAAAA789', os.environ) - - self.assertNotIn('AAAAAAAAAA123', os.environ) - self.assertEqual(os.environ['AAAAAAAAAA456'], '123') - self.assertEqual(os.environ['AAAAAAAAAA789'], '123') + self.assertNotIn("AAAAAAAAAA123", os.environ) + self.assertEqual(os.environ["AAAAAAAAAA456"], "123") + self.assertEqual(os.environ["AAAAAAAAAA789"], "123") finally: - for key in {'AAAAAAAAAA123', 'AAAAAAAAAA456', 'AAAAAAAAAA789'}: + for key in {"AAAAAAAAAA123", "AAAAAAAAAA456", "AAAAAAAAAA789"}: if key in os.environ: del os.environ[key] def test_test_connect_params_run_testcase(self): - with self.environ(EDGEDB_PORT='777'): - self.run_testcase({ - 'env': { - 'EDGEDB_HOST': 'abc' - }, - 'opts': { - 'user': '__test__', - }, - 'result': { - 'address': ['abc', 5656], - 'database': 'edgedb', - 'user': '__test__', - 'password': None, - 'secretKey': None, - 'tlsCAData': None, - 'tlsSecurity': 'strict', - 'serverSettings': {}, - 'waitUntilAvailable': 30, - }, - }) + with self.environ(EDGEDB_PORT="777"): + self.run_testcase( + { + "env": {"EDGEDB_HOST": "abc"}, + "opts": { + "user": "__test__", + }, + "result": { + "address": ["abc", 5656], + "database": "edgedb", + "user": "__test__", + "password": None, + "secretKey": None, + "tlsCAData": None, + "tlsSecurity": "strict", + "serverSettings": {}, + "waitUntilAvailable": 30, + }, + } + ) def test_connect_params(self): testcases_path = ( @@ -318,136 +343,149 @@ def test_connect_params(self): testcases = json.load(f) except FileNotFoundError as err: raise FileNotFoundError( - f'Failed to read "connection_testcases.json": {err}.\n' + - f'Is the "shared-client-testcases" submodule initialised? ' + - f'Try running "git submodule update --init".' + f'Failed to read "connection_testcases.json": {err}.\n' + + 'Is the "shared-client-testcases" submodule initialised? ' + + 'Try running "git submodule update --init".' ) for i, testcase in enumerate(testcases): with self.subTest(i=i): - wait_until_available = \ - testcase.get('result', {}).get('waitUntilAvailable') + wait_until_available = testcase.get("result", {}).get( + "waitUntilAvailable" + ) if wait_until_available: - testcase['result']['waitUntilAvailable'] = \ - con_utils._validate_wait_until_available( - wait_until_available) - platform = testcase.get('platform') - if testcase.get('fs') and ( - sys.platform == 'win32' or platform == 'windows' - or (platform is None and sys.platform == 'darwin') - or (platform == 'macos' and sys.platform != 'darwin') + testcase["result"][ + "waitUntilAvailable" + ] = con_utils._validate_wait_until_available( + wait_until_available + ) + platform = testcase.get("platform") + if testcase.get("fs") and ( + sys.platform == "win32" + or platform == "windows" + or (platform is None and sys.platform == "darwin") + or (platform == "macos" and sys.platform != "darwin") ): continue self.run_testcase(testcase) - @mock.patch("edgedb.platform.config_dir", - lambda: pathlib.Path("/home/user/.config/edgedb")) + @mock.patch( + "edgedb.platform.config_dir", + lambda: pathlib.Path("/home/user/.config/edgedb"), + ) @mock.patch("edgedb.platform.IS_WINDOWS", False) @mock.patch("pathlib.Path.exists", lambda p: True) @mock.patch("os.path.realpath", lambda p: p) def test_stash_path(self): self.assertEqual( con_utils._stash_path("/home/user/work/project1"), - pathlib.Path("/home/user/.config/edgedb/projects/project1-" - "cf1c841351bf7f147d70dcb6203441cf77a05249"), + pathlib.Path( + "/home/user/.config/edgedb/projects/project1-" + "cf1c841351bf7f147d70dcb6203441cf77a05249" + ), ) def test_project_config(self): with tempfile.TemporaryDirectory() as tmp: tmp = os.path.realpath(tmp) base = pathlib.Path(tmp) - home = base / 'home' - project = base / 'project' - projects = home / '.edgedb' / 'projects' - creds = home / '.edgedb' / 'credentials' + home = base / "home" + project = base / "project" + projects = home / ".edgedb" / "projects" + creds = home / ".edgedb" / "credentials" os.makedirs(projects) os.makedirs(creds) os.makedirs(project) - with open(project / 'edgedb.toml', 'wt') as f: - f.write('') # app don't read toml file - with open(creds / 'inst1.json', 'wt') as f: - f.write(json.dumps({ - "host": "inst1.example.org", - "port": 12323, - "user": "inst1_user", - "password": "passw1", - "database": "inst1_db", - })) - - with mock.patch('edgedb.platform.config_dir', - lambda: home / '.edgedb'), \ - mock.patch('os.getcwd', lambda: str(project)): + with open(project / "edgedb.toml", "wt") as f: + f.write("") # app don't read toml file + with open(creds / "inst1.json", "wt") as f: + f.write( + json.dumps( + { + "host": "inst1.example.org", + "port": 12323, + "user": "inst1_user", + "password": "passw1", + "database": "inst1_db", + } + ) + ) + + with mock.patch( + "edgedb.platform.config_dir", lambda: home / ".edgedb" + ), mock.patch("os.getcwd", lambda: str(project)): stash_path = con_utils._stash_path(project) - instance_file = stash_path / 'instance-name' + instance_file = stash_path / "instance-name" os.makedirs(stash_path) - with open(instance_file, 'wt') as f: - f.write('inst1') - - connect_config, client_config = ( - con_utils.parse_connect_arguments( - dsn=None, - host=None, - port=None, - credentials=None, - credentials_file=None, - user=None, - password=None, - secret_key=None, - database=None, - tls_ca=None, - tls_ca_file=None, - tls_security=None, - timeout=10, - command_timeout=None, - server_settings=None, - wait_until_available=30, - ) + with open(instance_file, "wt") as f: + f.write("inst1") + + ( + connect_config, + client_config, + ) = con_utils.parse_connect_arguments( + dsn=None, + host=None, + port=None, + credentials=None, + credentials_file=None, + user=None, + password=None, + secret_key=None, + database=None, + tls_ca=None, + tls_ca_file=None, + tls_security=None, + timeout=10, + command_timeout=None, + server_settings=None, + wait_until_available=30, ) - self.assertEqual(connect_config.address, ('inst1.example.org', 12323)) - self.assertEqual(connect_config.user, 'inst1_user') - self.assertEqual(connect_config.password, 'passw1') - self.assertEqual(connect_config.database, 'inst1_db') + self.assertEqual(connect_config.address, ("inst1.example.org", 12323)) + self.assertEqual(connect_config.user, "inst1_user") + self.assertEqual(connect_config.password, "passw1") + self.assertEqual(connect_config.database, "inst1_db") def test_validate_wait_until_available(self): invalid = [ - ' ', - ' PT1S', - '', - '-.1 s', - '-.1s', - '-.5 second', - '-.5 seconds', - '-.5second', - '-.5seconds', - '-.s', - '-1.s', - '.s', - '.seconds', - '1.s', - '1h-120m3600s', - '1hour-120minute3600second', - '1hours-120minutes3600seconds', - '1hours120minutes3600seconds', - '2.0hour46.0minutes39.0seconds', - '2.0hours46.0minutes39.0seconds', - '20 hours with other stuff should not be valid', - '20 minutes with other stuff should not be valid', - '20 ms with other stuff should not be valid', - '20 seconds with other stuff should not be valid', - '20 us with other stuff should not be valid', - '2hour46minute39second', - '2hours46minutes39seconds', - '3 hours is longer than 10 seconds', - 'P-.D', - 'P-D', - 'PD', - 'PT.S', - 'PT1S ', - '\t', - 'not a duration', - 's', + " ", + " PT1S", + "", + "-.1 s", + "-.1s", + "-.5 second", + "-.5 seconds", + "-.5second", + "-.5seconds", + "-.s", + "-1.s", + ".s", + ".seconds", + "1.s", + "1h-120m3600s", + "1hour-120minute3600second", + "1hours-120minutes3600seconds", + "1hours120minutes3600seconds", + "2.0hour46.0minutes39.0seconds", + "2.0hours46.0minutes39.0seconds", + "20 hours with other stuff should not be valid", + "20 minutes with other stuff should not be valid", + "20 ms with other stuff should not be valid", + "20 seconds with other stuff should not be valid", + "20 us with other stuff should not be valid", + "2hour46minute39second", + "2hours46minutes39seconds", + "3 hours is longer than 10 seconds", + "P-.D", + "P-D", + "PD", + "PT.S", + "PT1S ", + "\t", + "not a duration", + "s", ] for string in invalid: @@ -586,6 +624,5 @@ def test_validate_wait_until_available(self): for string, expected in valid: with self.subTest(duration=string): self.assertEqual( - expected, - con_utils._validate_wait_until_available(string) + expected, con_utils._validate_wait_until_available(string) ) diff --git a/tests/test_connect.py b/tests/test_connect.py index e22f1958..b5f6b400 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -25,7 +25,6 @@ class TestConnect(tb.AsyncQueryTestCase): - @classmethod def setUpClass(cls): super().setUpClass() @@ -35,7 +34,7 @@ def setUpClass(cls): def _get_free_port(cls): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - sock.bind(('127.0.0.1', 0)) + sock.bind(("127.0.0.1", 0)) return sock.getsockname()[1] except Exception: return None @@ -45,35 +44,39 @@ def _get_free_port(cls): async def test_connect_async_01(self): orig_conn_args = self.get_connect_args() conn_args = orig_conn_args.copy() - conn_args['port'] = self.port - conn_args['wait_until_available'] = 0 + conn_args["port"] = self.port + conn_args["wait_until_available"] = 0 with self.assertRaisesRegex( - edgedb.ClientConnectionError, - f'(?s).*Is the server running.*port {self.port}.*'): - conn_args['host'] = '127.0.0.1' + edgedb.ClientConnectionError, + f"(?s).*Is the server running.*port {self.port}.*", + ): + conn_args["host"] = "127.0.0.1" await edgedb.create_async_client(**conn_args).ensure_connected() with self.assertRaisesRegex( - edgedb.ClientConnectionError, - f'(?s).*Is the server running.*port {self.port}.*'): - conn_args['host'] = orig_conn_args['host'] + edgedb.ClientConnectionError, + f"(?s).*Is the server running.*port {self.port}.*", + ): + conn_args["host"] = orig_conn_args["host"] await edgedb.create_async_client(**conn_args).ensure_connected() def test_connect_sync_01(self): orig_conn_args = self.get_connect_args() conn_args = orig_conn_args.copy() - conn_args['port'] = self.port - conn_args['wait_until_available'] = 0 + conn_args["port"] = self.port + conn_args["wait_until_available"] = 0 with self.assertRaisesRegex( - edgedb.ClientConnectionError, - f'(?s).*Is the server running.*port {self.port}.*'): - conn_args['host'] = '127.0.0.1' + edgedb.ClientConnectionError, + f"(?s).*Is the server running.*port {self.port}.*", + ): + conn_args["host"] = "127.0.0.1" edgedb.create_client(**conn_args).ensure_connected() with self.assertRaisesRegex( - edgedb.ClientConnectionError, - f'(?s).*Is the server running.*port {self.port}.*'): - conn_args['host'] = orig_conn_args['host'] + edgedb.ClientConnectionError, + f"(?s).*Is the server running.*port {self.port}.*", + ): + conn_args["host"] = orig_conn_args["host"] edgedb.create_client(**conn_args).ensure_connected() diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 663d9c3d..c45679ca 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -45,48 +45,62 @@ def tearDown(self): def test_credentials_read(self): creds = credentials.read_credentials( - pathlib.Path(__file__).parent / 'credentials1.json') - self.assertEqual(creds, { - 'database': 'test3n', - 'password': 'lZTBy1RVCfOpBAOwSCwIyBIR', - 'port': 10702, - 'user': 'test3n', - }) + pathlib.Path(__file__).parent / "credentials1.json" + ) + self.assertEqual( + creds, + { + "database": "test3n", + "password": "lZTBy1RVCfOpBAOwSCwIyBIR", + "port": 10702, + "user": "test3n", + }, + ) def test_credentials_empty(self): - with self.assertRaisesRegex(ValueError, '`user` key is required'): + with self.assertRaisesRegex(ValueError, "`user` key is required"): credentials.validate_credentials({}) def test_credentials_port(self): - with self.assertRaisesRegex(ValueError, 'invalid `port` value'): - credentials.validate_credentials({ - 'user': 'u1', - 'port': '1234', - }) - - with self.assertRaisesRegex(ValueError, 'invalid `port` value'): - credentials.validate_credentials({ - 'user': 'u1', - 'port': 0, - }) - - with self.assertRaisesRegex(ValueError, 'invalid `port` value'): - credentials.validate_credentials({ - 'user': 'u1', - 'port': -1, - }) - - with self.assertRaisesRegex(ValueError, 'invalid `port` value'): - credentials.validate_credentials({ - 'user': 'u1', - 'port': 65536, - }) + with self.assertRaisesRegex(ValueError, "invalid `port` value"): + credentials.validate_credentials( + { + "user": "u1", + "port": "1234", + } + ) + + with self.assertRaisesRegex(ValueError, "invalid `port` value"): + credentials.validate_credentials( + { + "user": "u1", + "port": 0, + } + ) + + with self.assertRaisesRegex(ValueError, "invalid `port` value"): + credentials.validate_credentials( + { + "user": "u1", + "port": -1, + } + ) + + with self.assertRaisesRegex(ValueError, "invalid `port` value"): + credentials.validate_credentials( + { + "user": "u1", + "port": 65536, + } + ) def test_credentials_extra_key(self): - creds = credentials.validate_credentials(dict( - user='user1', - some_extra_data='test', - )) + creds = credentials.validate_credentials( + dict( + user="user1", + some_extra_data="test", + ) + ) # extra keys are ignored for forward compatibility # but aren't exported through validator self.assertEqual(creds, {"user": "user1", "port": 5656}) @@ -97,7 +111,9 @@ def test_get_credentials_path_macos(self, home_method): importlib.reload(platform) home_method.return_value = pathlib.PurePosixPath("/Users/edgedb") with mock.patch( - "pathlib.PurePosixPath.exists", lambda x: True, create=True, + "pathlib.PurePosixPath.exists", + lambda x: True, + create=True, ): self.assertEqual( str(credentials.get_credentials_path("test")), @@ -105,7 +121,9 @@ def test_get_credentials_path_macos(self, home_method): "edgedb/credentials/test.json", ) with mock.patch( - "pathlib.PurePosixPath.exists", _MockExists(), create=True, + "pathlib.PurePosixPath.exists", + _MockExists(), + create=True, ): self.assertEqual( str(credentials.get_credentials_path("test")), @@ -135,7 +153,7 @@ def get_folder_path(_a, _b, _c, _d, path_buf): with mock.patch( "pathlib.PureWindowsPath.exists", _MockExists(), create=True ), mock.patch( - 'pathlib.PureWindowsPath.home', + "pathlib.PureWindowsPath.home", lambda: pathlib.PureWindowsPath(r"c:\Users\edgedb"), create=True, ): diff --git a/tests/test_datetime.py b/tests/test_datetime.py index 08199077..fa87cc8c 100644 --- a/tests/test_datetime.py +++ b/tests/test_datetime.py @@ -26,9 +26,7 @@ class TestDatetimeTypes(tb.SyncQueryTestCase): - async def test_duration_01(self): - duration_kwargs = [ dict(), dict(microseconds=1), @@ -54,10 +52,13 @@ async def test_duration_01(self): durs = [timedelta(**d) for d in duration_kwargs] # Test encode/decode roundtrip - durs_from_db = self.client.query(''' + durs_from_db = self.client.query( + """ WITH args := array_unpack(>$0) SELECT args; - ''', durs) + """, + durs, + ) self.assertEqual(list(durs_from_db), durs) async def test_relative_duration_01(self): @@ -84,7 +85,7 @@ async def test_relative_duration_01(self): dict( microseconds=random.randint(-1000000000, 1000000000), days=random.randint(-500, 500), - months=random.randint(-50, 50) + months=random.randint(-50, 50), ) ) @@ -92,16 +93,22 @@ async def test_relative_duration_01(self): # Test that RelativeDuration.__str__ formats the # same as - durs_as_text = self.client.query(''' + durs_as_text = self.client.query( + """ WITH args := array_unpack(>$0) SELECT args; - ''', durs) + """, + durs, + ) # Test encode/decode roundtrip - durs_from_db = self.client.query(''' + durs_from_db = self.client.query( + """ WITH args := array_unpack(>$0) SELECT args; - ''', durs) + """, + durs, + ) self.assertEqual(durs_as_text, [str(d) for d in durs]) self.assertEqual(list(durs_from_db), durs) @@ -145,7 +152,7 @@ async def test_date_duration_01(self): delta_kwargs.append( dict( days=random.randint(-500, 500), - months=random.randint(-50, 50) + months=random.randint(-50, 50), ) ) @@ -153,16 +160,22 @@ async def test_date_duration_01(self): # Test that DateDuration.__str__ formats the # same as - durs_as_text = self.client.query(''' + durs_as_text = self.client.query( + """ WITH args := array_unpack(>$0) SELECT args; - ''', durs) + """, + durs, + ) # Test encode/decode roundtrip - durs_from_db = self.client.query(''' + durs_from_db = self.client.query( + """ WITH args := array_unpack(>$0) SELECT args; - ''', durs) + """, + durs, + ) for db_dur, client_dur in zip(durs_as_text, durs): self.assertEqual(db_dur, str(client_dur)) diff --git a/tests/test_enum.py b/tests/test_enum.py index 1f3b6fae..8013bba7 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -26,11 +26,10 @@ class TestEnum(tb.AsyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE SCALAR TYPE CellType EXTENDING enum<'red', 'white'>; CREATE SCALAR TYPE Color EXTENDING enum<'red', 'white'>; - ''' + """ async def test_enum_01(self): ct_red = await self.client.query_single('SELECT "red"') @@ -42,11 +41,11 @@ async def test_enum_01(self): self.assertEqual(repr(ct_red), "") - self.assertEqual(str(ct_red), 'red') + self.assertEqual(str(ct_red), "red") with self.assertRaises(TypeError): - _ = ct_red != 'red' + _ = ct_red != "red" with self.assertRaises(TypeError): - _ = ct_red == 'red' + _ = ct_red == "red" self.assertFalse(ct_red == c_red) self.assertEqual(ct_red, ct_red) @@ -59,13 +58,13 @@ async def test_enum_01(self): self.assertGreaterEqual(ct_white, ct_white) with self.assertRaises(TypeError): - _ = ct_red < 'red' + _ = ct_red < "red" with self.assertRaises(TypeError): - _ = ct_red > 'red' + _ = ct_red > "red" with self.assertRaises(TypeError): - _ = ct_red <= 'red' + _ = ct_red <= "red" with self.assertRaises(TypeError): - _ = ct_red >= 'red' + _ = ct_red >= "red" with self.assertRaises(TypeError): _ = ct_red < c_red @@ -77,17 +76,17 @@ async def test_enum_01(self): _ = ct_red >= c_red self.assertEqual(hash(ct_red), hash(c_red)) - self.assertEqual(hash(ct_red), hash('red')) + self.assertEqual(hash(ct_red), hash("red")) async def test_enum_02(self): c_red = await self.client.query_single('SELECT "red"') self.assertIsInstance(c_red, enum.Enum) - self.assertEqual(c_red.name, 'RED') - self.assertEqual(c_red.value, 'red') + self.assertEqual(c_red.name, "RED") + self.assertEqual(c_red.value, "red") class Color(enum.Enum): - RED = 'red' - WHITE = 'white' + RED = "red" + WHITE = "white" @dataclasses.dataclass class Container: @@ -95,18 +94,18 @@ class Container: c = Container(c_red) d = dataclasses.asdict(c) - self.assertIs(d['color'], c_red) + self.assertIs(d["color"], c_red) async def test_enum_03(self): c_red = await self.client.query_single('SELECT "red"') - c_red2 = await self.client.query_single('SELECT $0', c_red) + c_red2 = await self.client.query_single("SELECT $0", c_red) self.assertIs(c_red, c_red2) async def test_enum_04(self): enums = await self.client.query_single( - 'SELECT >$0', ['red', 'white'] + "SELECT >$0", ["red", "white"] ) enums2 = await self.client.query_single( - 'SELECT >$0', enums + "SELECT >$0", enums ) self.assertEqual(enums, enums2) diff --git a/tests/test_errors.py b/tests/test_errors.py index f76efa28..d285d1a9 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -25,37 +25,36 @@ class TestErrors(unittest.TestCase): - def test_errors_1(self): new = base_errors.EdgeDBError._from_code - e = new(0x_04_00_00_00, 'aa') + e = new(0x_04_00_00_00, "aa") self.assertIs(type(e), errors.QueryError) self.assertEqual(e.get_code(), 0x_04_00_00_00) - e = new(0x_04_01_00_00, 'aa') + e = new(0x_04_01_00_00, "aa") self.assertIs(type(e), errors.InvalidSyntaxError) self.assertEqual(e.get_code(), 0x_04_01_00_00) - e = new(0x_04_01_01_00, 'aa') + e = new(0x_04_01_01_00, "aa") self.assertIs(type(e), errors.EdgeQLSyntaxError) self.assertEqual(e.get_code(), 0x_04_01_01_00) - e = new(0x_04_01_01_FF, 'aa') + e = new(0x_04_01_01_FF, "aa") self.assertIs(type(e), errors.EdgeQLSyntaxError) self.assertEqual(e.get_code(), 0x_04_01_01_FF) - e = new(0x_04_01_FF_FF, 'aa') + e = new(0x_04_01_FF_FF, "aa") self.assertIs(type(e), errors.InvalidSyntaxError) self.assertEqual(e.get_code(), 0x_04_01_FF_FF) - e = new(0x_04_00_FF_FF, 'aa') + e = new(0x_04_00_FF_FF, "aa") self.assertIs(type(e), errors.QueryError) self.assertEqual(e.get_code(), 0x_04_00_FF_FF) def test_errors_2(self): new = base_errors.EdgeDBError._from_code - e = new(0x_F9_1E_FF_F1, 'aa') + e = new(0x_F9_1E_FF_F1, "aa") self.assertEqual(e.get_code(), 0x_F9_1E_FF_F1) self.assertIs(type(e), errors.EdgeDBError) diff --git a/tests/test_globals.py b/tests/test_globals.py index f8badbd4..a0dd705d 100644 --- a/tests/test_globals.py +++ b/tests/test_globals.py @@ -22,13 +22,13 @@ class TestGlobals(tb.AsyncQueryTestCase): - async def test_globals_01(self): db = self.client if db.is_proto_lt_1_0: self.skipTest("Global is added in EdgeDB 2.0") - await db.execute(''' + await db.execute( + """ CREATE GLOBAL glob -> str; CREATE REQUIRED GLOBAL req_glob -> str { SET default := '!'; @@ -36,29 +36,30 @@ async def test_globals_01(self): CREATE GLOBAL def_glob -> str { SET default := '!'; }; - ''') + """ + ) - async with db.with_globals(glob='test') as gdb: - x = await gdb.query_single('select global glob') - self.assertEqual(x, 'test') + async with db.with_globals(glob="test") as gdb: + x = await gdb.query_single("select global glob") + self.assertEqual(x, "test") - x = await gdb.query_single('select global req_glob') - self.assertEqual(x, '!') + x = await gdb.query_single("select global req_glob") + self.assertEqual(x, "!") - x = await gdb.query_single('select global def_glob') - self.assertEqual(x, '!') + x = await gdb.query_single("select global def_glob") + self.assertEqual(x, "!") - async with db.with_globals(req_glob='test') as gdb: - x = await gdb.query_single('select global req_glob') - self.assertEqual(x, 'test') + async with db.with_globals(req_glob="test") as gdb: + x = await gdb.query_single("select global req_glob") + self.assertEqual(x, "test") - async with db.with_globals(def_glob='test') as gdb: - x = await gdb.query_single('select global def_glob') - self.assertEqual(x, 'test') + async with db.with_globals(def_glob="test") as gdb: + x = await gdb.query_single("select global def_glob") + self.assertEqual(x, "test") # Setting def_glob explicitly to None should override async with db.with_globals(def_glob=None) as gdb: - x = await gdb.query_single('select global def_glob') + x = await gdb.query_single("select global def_glob") self.assertEqual(x, None) async def test_client_state_mismatch(self): @@ -66,17 +67,17 @@ async def test_client_state_mismatch(self): if db.is_proto_lt_1_0: self.skipTest("State over protocol is added in EdgeDB 2.0") - await db.execute('create global mglob -> int32') + await db.execute("create global mglob -> int32") c = self.make_test_client(database=self.get_database_name()) c = c.with_globals(mglob=42) - self.assertEqual(await c.query_single('select global mglob'), 42) + self.assertEqual(await c.query_single("select global mglob"), 42) - await db.execute('create global mglob2 -> str') - self.assertEqual(await c.query_single('select global mglob'), 42) + await db.execute("create global mglob2 -> str") + self.assertEqual(await c.query_single("select global mglob"), 42) - await db.execute('alter global mglob set type str reset to default') + await db.execute("alter global mglob set type str reset to default") with self.assertRaises(errors.InvalidArgumentError): - await c.query_single('select global mglob') + await c.query_single("select global mglob") await c.aclose() diff --git a/tests/test_memory.py b/tests/test_memory.py index 63c032e2..4dee206d 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -20,13 +20,13 @@ class TestConfigMemory(tb.SyncQueryTestCase): - async def test_config_memory_01(self): if ( self.client.query_required_single( "select exists " "(select schema::Type filter .name = 'cfg::memory')" - ) is False + ) + is False ): self.skipTest("feature not implemented") @@ -44,29 +44,33 @@ async def test_config_memory_01(self): # Test that ConfigMemory.__str__ formats the # same as - mem_tuples = self.client.query(''' + mem_tuples = self.client.query( + """ WITH args := array_unpack(>$0) SELECT ( args, args, args ); - ''', mem_strs) + """, + mem_strs, + ) mem_vals = [t[0] for t in mem_tuples] # Test encode/decode roundtrip - roundtrip = self.client.query(''' + roundtrip = self.client.query( + """ WITH args := array_unpack(>$0) SELECT args; - ''', mem_vals) + """, + mem_vals, + ) self.assertEqual( - [str(t[0]) for t in mem_tuples], - [t[1] for t in mem_tuples] + [str(t[0]) for t in mem_tuples], [t[1] for t in mem_tuples] ) self.assertEqual( - [t[0].as_bytes() for t in mem_tuples], - [t[2] for t in mem_tuples] + [t[0].as_bytes() for t in mem_tuples], [t[2] for t in mem_tuples] ) self.assertEqual(list(roundtrip), mem_vals) diff --git a/tests/test_proto.py b/tests/test_proto.py index 48be7f15..9510ab6e 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -24,52 +24,58 @@ class TestProto(tb.SyncQueryTestCase): - def test_json(self): self.assertEqual( - self.client.query_json('SELECT {"aaa", "bbb"}'), - '["aaa", "bbb"]') + self.client.query_json('SELECT {"aaa", "bbb"}'), '["aaa", "bbb"]' + ) # std::datetime is now in range of Python datetime, # so another way to trigger codec failure is needed. - @unittest.skip(""" + @unittest.skip( + """ std::datetime is now in range of Python date, so another way to trigger codec failure is needed. - """) + """ + ) async def test_proto_codec_error_recovery_01(self): for _ in range(5): # execute a few times for OE with self.assertRaisesRegex( - edgedb.ClientError, - "unable to decode data to Python objects"): + edgedb.ClientError, "unable to decode data to Python objects" + ): # Python dattime.Date object can't represent this date, so # we know that the codec will fail. # The test will be rewritten once it's possible to override # default codecs. - self.client.query(""" + self.client.query( + """ SELECT cal::to_local_date('0001-01-01 BC', 'YYYY-MM-DD AD'); - """) + """ + ) # The protocol, though, shouldn't be in some inconsistent # state; it should allow new queries to execute successfully. self.assertEqual( - self.client.query('SELECT {"aaa", "bbb"}'), - ['aaa', 'bbb']) + self.client.query('SELECT {"aaa", "bbb"}'), ["aaa", "bbb"] + ) - @unittest.skip(""" + @unittest.skip( + """ std::date is now in range of Python date, so another way to trigger codec failure is needed. - """) + """ + ) async def test_proto_codec_error_recovery_02(self): for _ in range(5): # execute a few times for OE with self.assertRaisesRegex( - edgedb.ClientError, - "unable to decode data to Python objects"): + edgedb.ClientError, "unable to decode data to Python objects" + ): # Python dattime.Date object can't represent this date, so # we know that the codec will fail. # The test will be rewritten once it's possible to override # default codecs. - self.client.query(r""" + self.client.query( + r""" SELECT cal::to_local_date( { '2010-01-01 AD', @@ -81,10 +87,11 @@ async def test_proto_codec_error_recovery_02(self): }, 'YYYY-MM-DD AD' ); - """) + """ + ) # The protocol, though, shouldn't be in some inconsistent # state; it should allow new queries to execute successfully. self.assertEqual( - self.client.query('SELECT {"aaa", "bbb"}'), - ['aaa', 'bbb']) + self.client.query('SELECT {"aaa", "bbb"}'), ["aaa", "bbb"] + ) diff --git a/tests/test_scram.py b/tests/test_scram.py index ff3aeb3c..6250407d 100644 --- a/tests/test_scram.py +++ b/tests/test_scram.py @@ -24,27 +24,27 @@ class TestSCRAM(unittest.TestCase): - def test_scram_sha_256_rfc_example(self): # Test SCRAM-SHA-256 against an example in RFC 7677 - username = 'user' - password = 'pencil' - client_nonce = 'rOprNGfwEbeRWgbNEkqO' - server_nonce = 'rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0' - salt = 'W22ZaJ0SNY7soEsUEjb6gQ==' - channel_binding = 'biws' + username = "user" + password = "pencil" + client_nonce = "rOprNGfwEbeRWgbNEkqO" + server_nonce = "rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0" + salt = "W22ZaJ0SNY7soEsUEjb6gQ==" + channel_binding = "biws" iterations = 4096 - client_first = f'n={username},r={client_nonce}' - server_first = f'r={server_nonce},s={salt},i={iterations}' - client_final = f'c={channel_binding},r={server_nonce}' + client_first = f"n={username},r={client_nonce}" + server_first = f"r={server_nonce},s={salt},i={iterations}" + client_final = f"c={channel_binding},r={server_nonce}" - AuthMessage = f'{client_first},{server_first},{client_final}'.encode() + AuthMessage = f"{client_first},{server_first},{client_final}".encode() SaltedPassword = scram.get_salted_password( - scram.saslprep(password).encode('utf-8'), + scram.saslprep(password).encode("utf-8"), base64.b64decode(salt), - iterations) + iterations, + ) ClientKey = scram.get_client_key(SaltedPassword) ServerKey = scram.get_server_key(SaltedPassword) StoredKey = scram.H(ClientKey) @@ -52,15 +52,19 @@ def test_scram_sha_256_rfc_example(self): ClientProof = scram.XOR(ClientKey, ClientSignature) ServerProof = scram.HMAC(ServerKey, AuthMessage) - self.assertEqual(scram.B64(ClientProof), - 'dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=') + self.assertEqual( + scram.B64(ClientProof), + "dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=", + ) - self.assertEqual(scram.B64(ServerProof), - '6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=') + self.assertEqual( + scram.B64(ServerProof), + "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=", + ) def test_scram_sha_256_verifier(self): - salt = 'W22ZaJ0SNY7soEsUEjb6gQ==' - password = 'pencil' + salt = "W22ZaJ0SNY7soEsUEjb6gQ==" + password = "pencil" v = scram.build_verifier( password, @@ -68,15 +72,16 @@ def test_scram_sha_256_verifier(self): iterations=4096, ) - stored_key = 'WG5d8oPm3OtcPnkdi4Uo7BkeZkBFzpcXkuLmtbsT4qY=' - server_key = 'wfPLwcE6nTWhTAmQ7tl2KeoiWGPlZqQxSrmfPwDl2dU=' + stored_key = "WG5d8oPm3OtcPnkdi4Uo7BkeZkBFzpcXkuLmtbsT4qY=" + server_key = "wfPLwcE6nTWhTAmQ7tl2KeoiWGPlZqQxSrmfPwDl2dU=" self.assertEqual( - v, f'SCRAM-SHA-256$4096:{salt}${stored_key}:{server_key}') + v, f"SCRAM-SHA-256$4096:{salt}${stored_key}:{server_key}" + ) parsed = scram.parse_verifier(v) - self.assertEqual(parsed.mechanism, 'SCRAM-SHA-256') + self.assertEqual(parsed.mechanism, "SCRAM-SHA-256") self.assertEqual(parsed.iterations, 4096) self.assertEqual(parsed.salt, base64.b64decode(salt)) self.assertEqual(parsed.stored_key, base64.b64decode(stored_key)) diff --git a/tests/test_sourcecode.py b/tests/test_sourcecode.py index 1836c2d0..12bd0a9c 100644 --- a/tests/test_sourcecode.py +++ b/tests/test_sourcecode.py @@ -28,19 +28,18 @@ def find_edgedb_root(): class TestFlake8(unittest.TestCase): - def test_flake8(self): edgepath = find_edgedb_root() - config_path = os.path.join(edgepath, '.flake8') + config_path = os.path.join(edgepath, ".flake8") if not os.path.exists(config_path): - raise RuntimeError('could not locate .flake8 file') + raise RuntimeError("could not locate .flake8 file") try: import flake8 # NoQA except ImportError: - raise unittest.SkipTest('flake8 moudule is missing') + raise unittest.SkipTest("flake8 moudule is missing") - for subdir in ['edgedb', 'tests']: # ignore any top-level test files + for subdir in ["edgedb", "tests"]: # ignore any top-level test files try: subprocess.run( [ @@ -54,8 +53,10 @@ def test_flake8(self): check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - cwd=edgepath) + cwd=edgepath, + ) except subprocess.CalledProcessError as ex: output = ex.output.decode() raise AssertionError( - f'flake8 validation failed:\n{output}') from None + f"flake8 validation failed:\n{output}" + ) from None diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py index 8dd35c3c..b20a16fe 100644 --- a/tests/test_sync_query.py +++ b/tests/test_sync_query.py @@ -32,84 +32,85 @@ class TestSyncQuery(tb.SyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE TYPE test::Tmp { CREATE REQUIRED PROPERTY tmp -> std::str; }; CREATE SCALAR TYPE MyEnum EXTENDING enum<"A", "B">; - ''' + """ - TEARDOWN = ''' + TEARDOWN = """ DROP TYPE test::Tmp; - ''' + """ def test_sync_parse_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.client.query('select syntax error') + self.client.query("select syntax error") with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.client.query('select syntax error') + self.client.query("select syntax error") - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): - self.client.query('select (') + with self.assertRaisesRegex( + edgedb.EdgeQLSyntaxError, "Unexpected end of line" + ): + self.client.query("select (") - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): - self.client.query_json('select (') + with self.assertRaisesRegex( + edgedb.EdgeQLSyntaxError, "Unexpected end of line" + ): + self.client.query_json("select (") for _ in range(10): self.assertEqual( - self.client.query('select 1;'), - edgedb.Set((1,))) + self.client.query("select 1;"), edgedb.Set((1,)) + ) self.assertFalse(self.client.connection.is_closed()) def test_sync_parse_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.client.execute('select syntax error') + self.client.execute("select syntax error") with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.client.execute('select syntax error') + self.client.execute("select syntax error") for _ in range(10): - self.client.execute('select 1; select 2;'), + self.client.execute("select 1; select 2;"), def test_sync_exec_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - self.client.query('select 1 / 0;') + self.client.query("select 1 / 0;") with self.assertRaises(edgedb.DivisionByZeroError): - self.client.query('select 1 / 0;') + self.client.query("select 1 / 0;") for _ in range(10): self.assertEqual( - self.client.query('select 1;'), - edgedb.Set((1,))) + self.client.query("select 1;"), edgedb.Set((1,)) + ) def test_sync_exec_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - self.client.execute('select 1 / 0;') + self.client.execute("select 1 / 0;") with self.assertRaises(edgedb.DivisionByZeroError): - self.client.execute('select 1 / 0;') + self.client.execute("select 1 / 0;") for _ in range(10): - self.client.execute('select 1;') + self.client.execute("select 1;") def test_sync_exec_error_recover_03(self): - query = 'select 10 // $0;' + query = "select 10 // $0;" for i in [1, 2, 0, 3, 1, 0, 1]: if i: self.assertEqual( - self.client.query(query, i), - edgedb.Set([10 // i])) + self.client.query(query, i), edgedb.Set([10 // i]) + ) else: with self.assertRaises(edgedb.DivisionByZeroError): self.client.query(query, i) @@ -117,17 +118,15 @@ def test_sync_exec_error_recover_03(self): def test_sync_exec_error_recover_04(self): for i in [1, 2, 0, 3, 1, 0, 1]: if i: - self.client.execute(f'select 10 // {i};') + self.client.execute(f"select 10 // {i};") else: with self.assertRaises(edgedb.DivisionByZeroError): - self.client.query(f'select 10 // {i};') + self.client.query(f"select 10 // {i};") def test_sync_exec_error_recover_05(self): with self.assertRaises(edgedb.DivisionByZeroError): - self.client.execute(f'select 1 / 0') - self.assertEqual( - self.client.query('SELECT "HELLO"'), - ["HELLO"]) + self.client.execute("select 1 / 0") + self.assertEqual(self.client.query('SELECT "HELLO"'), ["HELLO"]) def test_sync_query_single_01(self): res = self.client.query_single("SELECT 1") @@ -141,195 +140,212 @@ def test_sync_query_single_01(self): self.client.query_required_single("SELECT {}") def test_sync_query_single_command_01(self): - r = self.client.query(''' + r = self.client.query( + """ CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; }; - ''') + """ + ) self.assertEqual(r, []) - r = self.client.query(''' + r = self.client.query( + """ DROP TYPE test::server_query_single_command_01; - ''') + """ + ) self.assertEqual(r, []) - r = self.client.query(''' + r = self.client.query( + """ CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; }; - ''') + """ + ) self.assertEqual(r, []) - r = self.client.query(''' + r = self.client.query( + """ DROP TYPE test::server_query_single_command_01; - ''') + """ + ) self.assertEqual(r, []) - r = self.client.query_json(''' + r = self.client.query_json( + """ CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; }; - ''') - self.assertEqual(r, '[]') + """ + ) + self.assertEqual(r, "[]") - r = self.client.query_json(''' + r = self.client.query_json( + """ DROP TYPE test::server_query_single_command_01; - ''') - self.assertEqual(r, '[]') + """ + ) + self.assertEqual(r, "[]") self.assertTrue( - self.client.connection._get_last_status().startswith('DROP') + self.client.connection._get_last_status().startswith("DROP") ) def test_sync_query_no_return(self): with self.assertRaisesRegex( - edgedb.InterfaceError, - r'cannot be executed with query_required_single\(\).*' - r'not return'): - self.client.query_required_single('create type Foo123') + edgedb.InterfaceError, + r"cannot be executed with query_required_single\(\).*" + r"not return", + ): + self.client.query_required_single("create type Foo123") with self.assertRaisesRegex( - edgedb.InterfaceError, - r'cannot be executed with query_required_single_json\(\).*' - r'not return'): - self.client.query_required_single_json('create type Bar123') + edgedb.InterfaceError, + r"cannot be executed with query_required_single_json\(\).*" + r"not return", + ): + self.client.query_required_single_json("create type Bar123") def test_sync_basic_datatypes_01(self): for _ in range(10): - self.assertEqual( - self.client.query_single( - 'select ()'), - ()) + self.assertEqual(self.client.query_single("select ()"), ()) self.assertEqual( - self.client.query( - 'select (1,)'), - edgedb.Set([(1,)])) + self.client.query("select (1,)"), edgedb.Set([(1,)]) + ) self.assertEqual( - self.client.query_single( - 'select >[]'), - []) + self.client.query_single("select >[]"), [] + ) self.assertEqual( - self.client.query( - 'select ["a", "b"]'), - edgedb.Set([["a", "b"]])) + self.client.query('select ["a", "b"]'), + edgedb.Set([["a", "b"]]), + ) self.assertEqual( - self.client.query(''' + self.client.query( + """ SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; - '''), - edgedb.Set([ - edgedb.NamedTuple(a=42, world=("hello", 32)), - edgedb.NamedTuple(a=1, world=("yo", 10)), - ])) + """ + ), + edgedb.Set( + [ + edgedb.NamedTuple(a=42, world=("hello", 32)), + edgedb.NamedTuple(a=1, world=("yo", 10)), + ] + ), + ) with self.assertRaisesRegex( - edgedb.InterfaceError, - r'query cannot be executed with query_single\('): - self.client.query_single('SELECT {1, 2}') + edgedb.InterfaceError, + r"query cannot be executed with query_single\(", + ): + self.client.query_single("SELECT {1, 2}") - with self.assertRaisesRegex(edgedb.NoDataError, - r'\bquery_required_single_json\('): - self.client.query_required_single_json('SELECT {}') + with self.assertRaisesRegex( + edgedb.NoDataError, r"\bquery_required_single_json\(" + ): + self.client.query_required_single_json("SELECT {}") def test_sync_basic_datatypes_02(self): self.assertEqual( - self.client.query( - r'''select [b"\x00a", b"b", b'', b'\na']'''), - edgedb.Set([[b"\x00a", b"b", b'', b'\na']])) + self.client.query(r"""select [b"\x00a", b"b", b'', b'\na']"""), + edgedb.Set([[b"\x00a", b"b", b"", b"\na"]]), + ) self.assertEqual( - self.client.query( - r'select $0', b'he\x00llo'), - edgedb.Set([b'he\x00llo'])) + self.client.query(r"select $0", b"he\x00llo"), + edgedb.Set([b"he\x00llo"]), + ) def test_sync_basic_datatypes_03(self): for _ in range(10): - self.assertEqual( - self.client.query_json( - 'select ()'), - '[[]]') + self.assertEqual(self.client.query_json("select ()"), "[[]]") - self.assertEqual( - self.client.query_json( - 'select (1,)'), - '[[1]]') + self.assertEqual(self.client.query_json("select (1,)"), "[[1]]") self.assertEqual( - self.client.query_json( - 'select >[]'), - '[[]]') + self.client.query_json("select >[]"), "[[]]" + ) self.assertEqual( - json.loads( - self.client.query_json( - 'select ["a", "b"]')), - [["a", "b"]]) + json.loads(self.client.query_json('select ["a", "b"]')), + [["a", "b"]], + ) self.assertEqual( - json.loads( - self.client.query_single_json( - 'select ["a", "b"]')), - ["a", "b"]) + json.loads(self.client.query_single_json('select ["a", "b"]')), + ["a", "b"], + ) self.assertEqual( json.loads( - self.client.query_json(''' + self.client.query_json( + """ SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; - ''')), + """ + ) + ), [ {"a": 42, "world": ["hello", 32]}, - {"a": 1, "world": ["yo", 10]} - ]) + {"a": 1, "world": ["yo", 10]}, + ], + ) self.assertEqual( - json.loads( - self.client.query_json('SELECT {1, 2}')), - [1, 2]) + json.loads(self.client.query_json("SELECT {1, 2}")), [1, 2] + ) self.assertEqual( - json.loads(self.client.query_json('SELECT {}')), - []) + json.loads(self.client.query_json("SELECT {}")), [] + ) with self.assertRaises(edgedb.NoDataError): - self.client.query_required_single_json('SELECT {}') + self.client.query_required_single_json("SELECT {}") self.assertEqual( - json.loads(self.client.query_single_json('SELECT {}')), - None + json.loads(self.client.query_single_json("SELECT {}")), + None, ) def test_sync_args_01(self): self.assertEqual( self.client.query( - 'select (>$foo)[0] ++ (>$bar)[0];', - foo=['aaa'], bar=['bbb']), - edgedb.Set(('aaabbb',))) + "select (>$foo)[0] ++ (>$bar)[0];", + foo=["aaa"], + bar=["bbb"], + ), + edgedb.Set(("aaabbb",)), + ) def test_sync_args_02(self): self.assertEqual( self.client.query( - 'select (>$0)[0] ++ (>$1)[0];', - ['aaa'], ['bbb']), - edgedb.Set(('aaabbb',))) + "select (>$0)[0] ++ (>$1)[0];", + ["aaa"], + ["bbb"], + ), + edgedb.Set(("aaabbb",)), + ) def test_sync_args_03(self): - with self.assertRaisesRegex(edgedb.QueryError, r'missing \$0'): - self.client.query('select $1;') + with self.assertRaisesRegex(edgedb.QueryError, r"missing \$0"): + self.client.query("select $1;") - with self.assertRaisesRegex(edgedb.QueryError, r'missing \$1'): - self.client.query('select $0 + $2;') + with self.assertRaisesRegex(edgedb.QueryError, r"missing \$1"): + self.client.query("select $0 + $2;") - with self.assertRaisesRegex(edgedb.QueryError, - 'combine positional and named parameters'): - self.client.query('select $0 + $bar;') + with self.assertRaisesRegex( + edgedb.QueryError, "combine positional and named parameters" + ): + self.client.query("select $0 + $bar;") def test_sync_args_04(self): aware_datetime = datetime.datetime.now(datetime.timezone.utc) @@ -340,112 +356,113 @@ def test_sync_args_04(self): aware_time = datetime.time(hour=11, tzinfo=datetime.timezone.utc) self.assertEqual( - self.client.query_single( - 'select $0;', - aware_datetime), - aware_datetime) + self.client.query_single("select $0;", aware_datetime), + aware_datetime, + ) self.assertEqual( self.client.query_single( - 'select $0;', - naive_datetime), - naive_datetime) + "select $0;", naive_datetime + ), + naive_datetime, + ) self.assertEqual( - self.client.query_single( - 'select $0;', - date), - date) + self.client.query_single("select $0;", date), date + ) self.assertEqual( self.client.query_single( - 'select $0;', - naive_time), - naive_time) + "select $0;", naive_time + ), + naive_time, + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'a timezone-aware.*expected'): - self.client.query_single( - 'select $0;', - naive_datetime) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, r"a timezone-aware.*expected" + ): + self.client.query_single("select $0;", naive_datetime) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'a naive time object.*expected'): - self.client.query_single( - 'select $0;', - aware_time) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, r"a naive time object.*expected" + ): + self.client.query_single("select $0;", aware_time) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'a naive datetime object.*expected'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, r"a naive datetime object.*expected" + ): self.client.query_single( - 'select $0;', - aware_datetime) + "select $0;", aware_datetime + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'datetime.datetime object was expected'): - self.client.query_single( - 'select $0;', - date) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, + r"datetime.datetime object was expected", + ): + self.client.query_single("select $0;", date) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - r'datetime.datetime object was expected'): - self.client.query_single( - 'select $0;', - date) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, + r"datetime.datetime object was expected", + ): + self.client.query_single("select $0;", date) def test_sync_mismatched_args_01(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'a'} arguments, " - "got {'[bc]', '[bc]'}, " - r"missed {'a'}, extra {'[bc]', '[bc]'}"): - + edgedb.QueryArgumentError, + r"expected {'a'} arguments, " + "got {'[bc]', '[bc]'}, " + r"missed {'a'}, extra {'[bc]', '[bc]'}", + ): self.client.query("""SELECT $a;""", b=1, c=2) def test_sync_mismatched_args_02(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'[ab]', '[ab]'} arguments, " - r"got {'[acd]', '[acd]', '[acd]'}, " - r"missed {'b'}, extra {'[cd]', '[cd]'}"): - - self.client.query(""" + edgedb.QueryArgumentError, + r"expected {'[ab]', '[ab]'} arguments, " + r"got {'[acd]', '[acd]', '[acd]'}, " + r"missed {'b'}, extra {'[cd]', '[cd]'}", + ): + self.client.query( + """ SELECT $a + $b; - """, a=1, c=2, d=3) + """, + a=1, + c=2, + d=3, + ) def test_sync_mismatched_args_03(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - "expected {'a'} arguments, got {'b'}, " - "missed {'a'}, extra {'b'}"): - + edgedb.QueryArgumentError, + "expected {'a'} arguments, got {'b'}, " + "missed {'a'}, extra {'b'}", + ): self.client.query("""SELECT $a;""", b=1) def test_sync_mismatched_args_04(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'[ab]', '[ab]'} arguments, " - r"got {'a'}, " - r"missed {'b'}"): - + edgedb.QueryArgumentError, + r"expected {'[ab]', '[ab]'} arguments, " + r"got {'a'}, " + r"missed {'b'}", + ): self.client.query("""SELECT $a + $b;""", a=1) def test_sync_mismatched_args_05(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'a'} arguments, " - r"got {'[ab]', '[ab]'}, " - r"extra {'b'}"): - + edgedb.QueryArgumentError, + r"expected {'a'} arguments, " + r"got {'[ab]', '[ab]'}, " + r"extra {'b'}", + ): self.client.query("""SELECT $a;""", a=1, b=2) def test_sync_mismatched_args_06(self): with self.assertRaisesRegex( - edgedb.QueryArgumentError, - r"expected {'a'} arguments, " - r"got nothing, " - r"missed {'a'}"): - + edgedb.QueryArgumentError, + r"expected {'a'} arguments, " r"got nothing, " r"missed {'a'}", + ): self.client.query("""SELECT $a;""") def test_sync_mismatched_args_07(self): @@ -453,40 +470,43 @@ def test_sync_mismatched_args_07(self): edgedb.QueryArgumentError, "expected no named arguments", ): - self.client.query("""SELECT 42""", a=1, b=2) def test_sync_args_uuid_pack(self): obj = self.client.query_single( - 'select schema::Object {id, name} limit 1') + "select schema::Object {id, name} limit 1" + ) # Test that the custom UUID that our driver uses can be # passed back as a parameter. ot = self.client.query_single( - 'select schema::Object {name} filter .id=$id', - id=obj.id) + "select schema::Object {name} filter .id=$id", id=obj.id + ) self.assertEqual(obj.id, ot.id) self.assertEqual(obj.name, ot.name) # Test that a string UUID is acceptable. ot = self.client.query_single( - 'select schema::Object {name} filter .id=$id', - id=str(obj.id)) + "select schema::Object {name} filter .id=$id", id=str(obj.id) + ) self.assertEqual(obj.id, ot.id) self.assertEqual(obj.name, ot.name) # Test that a standard uuid.UUID is acceptable. ot = self.client.query_single( - 'select schema::Object {name} filter .id=$id', - id=uuid.UUID(bytes=obj.id.bytes)) + "select schema::Object {name} filter .id=$id", + id=uuid.UUID(bytes=obj.id.bytes), + ) self.assertEqual(obj.id, ot.id) self.assertEqual(obj.name, ot.name) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'invalid UUID.*length must be'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "invalid UUID.*length must be" + ): self.client.query( - 'select schema::Object {name} filter .id=$id', - id='asdasas') + "select schema::Object {name} filter .id=$id", + id="asdasas", + ) def test_sync_args_bigint_basic(self): testar = [ @@ -546,97 +566,92 @@ def test_sync_args_bigint_basic(self): ] for _ in range(500): - num = '' + num = "" for _ in range(random.randint(1, 50)): num += random.choice("0123456789") testar.append(int(num)) for _ in range(500): - num = '' + num = "" for _ in range(random.randint(1, 50)): num += random.choice("0000000012") testar.append(int(num)) val = self.client.query_single( - 'select >$arg', - arg=testar) + "select >$arg", arg=testar + ) self.assertEqual(testar, val) def test_sync_args_bigint_pack(self): - val = self.client.query_single( - 'select $arg', - arg=10) + val = self.client.query_single("select $arg", arg=10) self.assertEqual(val, 10) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - self.client.query( - 'select $arg', - arg='bad int') + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + self.client.query("select $arg", arg="bad int") - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - self.client.query( - 'select $arg', - arg=10.11) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + self.client.query("select $arg", arg=10.11) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): self.client.query( - 'select $arg', - arg=decimal.Decimal('10.0')) + "select $arg", arg=decimal.Decimal("10.0") + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): self.client.query( - 'select $arg', - arg=decimal.Decimal('10.11')) + "select $arg", arg=decimal.Decimal("10.11") + ) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - self.client.query( - 'select $arg', - arg='10') + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + self.client.query("select $arg", arg="10") - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): self.client.query_single( - 'select $arg', - arg=decimal.Decimal('10')) + "select $arg", arg=decimal.Decimal("10") + ) + + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): class IntLike: def __int__(self): return 10 - self.client.query_single( - 'select $arg', - arg=IntLike()) + self.client.query_single("select $arg", arg=IntLike()) def test_sync_args_intlike(self): class IntLike: def __int__(self): return 10 - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + self.client.query_single("select $arg", arg=IntLike()) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + self.client.query_single("select $arg", arg=IntLike()) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected an int'): - self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected an int" + ): + self.client.query_single("select $arg", arg=IntLike()) def test_sync_args_decimal(self): class IntLike: @@ -644,28 +659,28 @@ def __int__(self): return 10 val = self.client.query_single( - 'select $0', decimal.Decimal("10.0") + "select $0", decimal.Decimal("10.0") ) self.assertEqual(val, 10) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected a Decimal or an int'): - self.client.query_single( - 'select $arg', - arg=IntLike()) + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected a Decimal or an int" + ): + self.client.query_single("select $arg", arg=IntLike()) - with self.assertRaisesRegex(edgedb.InvalidArgumentError, - 'expected a Decimal or an int'): - self.client.query_single( - 'select $arg', - arg="10.2") + with self.assertRaisesRegex( + edgedb.InvalidArgumentError, "expected a Decimal or an int" + ): + self.client.query_single("select $arg", arg="10.2") def test_sync_wait_cancel_01(self): - underscored_lock = self.client.query_single(""" + underscored_lock = self.client.query_single( + """ SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_advisory_lock' ) - """) + """ + ) if not underscored_lock: self.skipTest("No sys::_advisory_lock function") @@ -676,33 +691,37 @@ def test_sync_wait_cancel_01(self): client = self.client.with_retry_options( edgedb.RetryOptions(attempts=1) ) - client2 = self.make_test_client( - database=self.client.dbname - ).with_retry_options( - edgedb.RetryOptions(attempts=1) - ).ensure_connected() + client2 = ( + self.make_test_client(database=self.client.dbname) + .with_retry_options(edgedb.RetryOptions(attempts=1)) + .ensure_connected() + ) for tx in client.transaction(): with tx: - self.assertTrue(tx.query_single( - 'select sys::_advisory_lock($0)', - lock_key)) + self.assertTrue( + tx.query_single( + "select sys::_advisory_lock($0)", lock_key + ) + ) evt = threading.Event() def exec_to_fail(): - with self.assertRaises(( - edgedb.ClientConnectionClosedError, - edgedb.ClientConnectionFailedError, - )): + with self.assertRaises( + ( + edgedb.ClientConnectionClosedError, + edgedb.ClientConnectionFailedError, + ) + ): for tx2 in client2.transaction(): with tx2: # start the lazy transaction - tx2.query('SELECT 42;') + tx2.query("SELECT 42;") evt.set() tx2.query( - 'select sys::_advisory_lock($0)', + "select sys::_advisory_lock($0)", lock_key, ) @@ -725,12 +744,14 @@ def exec_to_fail(): t.join() self.assertEqual( tx.query( - 'select sys::_advisory_unlock($0)', - lock_key), - [True]) + "select sys::_advisory_unlock($0)", lock_key + ), + [True], + ) def test_empty_set_unpack(self): - self.client.query_single(''' + self.client.query_single( + """ select schema::Function { name, params: { @@ -740,43 +761,40 @@ def test_empty_set_unpack(self): } filter .name = 'std::str_repeat' limit 1 - ''') + """ + ) def test_enum_argument_01(self): - A = self.client.query_single('SELECT $0', 'A') - self.assertEqual(str(A), 'A') + A = self.client.query_single("SELECT $0", "A") + self.assertEqual(str(A), "A") with self.assertRaisesRegex( - edgedb.InvalidValueError, 'invalid input value for enum' + edgedb.InvalidValueError, "invalid input value for enum" ): for tx in self.client.transaction(): with tx: - tx.query_single('SELECT $0', 'Oups') + tx.query_single("SELECT $0", "Oups") - self.assertEqual( - self.client.query_single('SELECT $0', 'A'), - A) + self.assertEqual(self.client.query_single("SELECT $0", "A"), A) - self.assertEqual( - self.client.query_single('SELECT $0', A), - A) + self.assertEqual(self.client.query_single("SELECT $0", A), A) with self.assertRaisesRegex( - edgedb.InvalidValueError, 'invalid input value for enum' + edgedb.InvalidValueError, "invalid input value for enum" ): for tx in self.client.transaction(): with tx: - tx.query_single('SELECT $0', 'Oups') + tx.query_single("SELECT $0", "Oups") with self.assertRaisesRegex( - edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue' + edgedb.InvalidArgumentError, "a str or edgedb.EnumValue" ): - self.client.query_single('SELECT $0', 123) + self.client.query_single("SELECT $0", 123) def test_json(self): self.assertEqual( - self.client.query_json('SELECT {"aaa", "bbb"}'), - '["aaa", "bbb"]') + self.client.query_json('SELECT {"aaa", "bbb"}'), '["aaa", "bbb"]' + ) def test_json_elements(self): self.client.ensure_connected() @@ -797,33 +815,33 @@ def test_json_elements(self): ) ) ) - self.assertEqual( - result, - edgedb.Set(['"aaa"', '"bbb"'])) + self.assertEqual(result, edgedb.Set(['"aaa"', '"bbb"'])) def _test_sync_cancel_01(self): # TODO(fantix): enable when command_timeout is implemented - has_sleep = self.client.query_single(""" + has_sleep = self.client.query_single( + """ SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_sleep' ) - """) + """ + ) if not has_sleep: self.skipTest("No sys::_sleep function") client = self.make_test_client(database=self.client.dbname) try: - self.assertEqual(client.query_single('SELECT 1'), 1) + self.assertEqual(client.query_single("SELECT 1"), 1) protocol_before = client._impl._holders[0]._con._protocol with self.assertRaises(edgedb.InterfaceError): client.with_timeout_options(command_timeout=0.1).query_single( - 'SELECT sys::_sleep(10)' + "SELECT sys::_sleep(10)" ) - client.query('SELECT 2') + client.query("SELECT 2") protocol_after = client._impl._holders[0]._con._protocol self.assertIsNot( @@ -843,43 +861,51 @@ def on_log(con, msg): con.add_log_listener(on_log) try: self.client.query( - 'configure system set __internal_restart := true;' + "configure system set __internal_restart := true;" ) # self.con.query('SELECT 1') finally: con.remove_log_listener(on_log) for msg in msgs: - if (msg.get_severity_name() == 'NOTICE' and - 'server restart is required' in str(msg)): + if ( + msg.get_severity_name() == "NOTICE" + and "server restart is required" in str(msg) + ): break else: - raise AssertionError('a notice message was not delivered') + raise AssertionError("a notice message was not delivered") def test_sync_banned_transaction(self): with self.assertRaisesRegex( edgedb.CapabilityError, - r'cannot execute transaction control commands', + r"cannot execute transaction control commands", ): - self.client.query('start transaction') + self.client.query("start transaction") with self.assertRaisesRegex( edgedb.CapabilityError, - r'cannot execute transaction control commands', + r"cannot execute transaction control commands", ): - self.client.execute('start transaction') + self.client.execute("start transaction") def test_transaction_state(self): with self.assertRaisesRegex(edgedb.QueryError, "cannot assign to id"): for tx in self.client.transaction(): with tx: - tx.execute(''' + tx.execute( + """ INSERT test::Tmp { id := $0, tmp := '' } - ''', uuid.uuid4()) + """, + uuid.uuid4(), + ) client = self.client.with_config(allow_user_specified_id=True) for tx in client.transaction(): with tx: - tx.execute(''' + tx.execute( + """ INSERT test::Tmp { id := $0, tmp := '' } - ''', uuid.uuid4()) + """, + uuid.uuid4(), + ) diff --git a/tests/test_sync_retry.py b/tests/test_sync_retry.py index 831f0964..fbfcf248 100644 --- a/tests/test_sync_retry.py +++ b/tests/test_sync_retry.py @@ -45,8 +45,7 @@ def ready(self): class TestSyncRetry(tb.SyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE TYPE test::Counter EXTENDING std::Object { CREATE PROPERTY name -> std::str { CREATE CONSTRAINT std::exclusive; @@ -55,42 +54,50 @@ class TestSyncRetry(tb.SyncQueryTestCase): SET default := 0; }; }; - ''' + """ - TEARDOWN = ''' + TEARDOWN = """ DROP TYPE test::Counter; - ''' + """ def test_sync_retry_01(self): for tx in self.client.transaction(): with tx: - tx.execute(''' + tx.execute( + """ INSERT test::Counter { name := 'counter1' }; - ''') + """ + ) def test_sync_retry_02(self): with self.assertRaises(ZeroDivisionError): for tx in self.client.transaction(): with tx: - tx.execute(''' + tx.execute( + """ INSERT test::Counter { name := 'counter_retry_02' }; - ''') + """ + ) 1 / 0 with self.assertRaises(edgedb.NoDataError): - self.client.query_required_single(''' + self.client.query_required_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_02' - ''') + """ + ) self.assertEqual( - self.client.query_single(''' + self.client.query_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_02' - '''), - None + """ + ), + None, ) def test_sync_retry_begin(self): @@ -112,22 +119,28 @@ def cleanup(): with self.assertRaises(errors.BackendUnavailableError): for tx in self.client.transaction(): with tx: - tx.execute(''' + tx.execute( + """ INSERT test::Counter { name := 'counter_retry_begin' }; - ''') + """ + ) with self.assertRaises(edgedb.NoDataError): - self.client.query_required_single(''' + self.client.query_required_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_begin' - ''') + """ + ) self.assertEqual( - self.client.query_single(''' + self.client.query_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_begin' - '''), - None + """ + ), + None, ) def recover_after_first_error(*_, **__): @@ -139,28 +152,32 @@ def recover_after_first_error(*_, **__): for tx in self.client.transaction(): with tx: - tx.execute(''' + tx.execute( + """ INSERT test::Counter { name := 'counter_retry_begin' }; - ''') + """ + ) self.assertEqual(_start.call_count, call_count + 1) - self.client.query_single(''' + self.client.query_single( + """ SELECT test::Counter FILTER .name = 'counter_retry_begin' - ''') + """ + ) def test_sync_retry_conflict(self): - self.execute_conflict('counter2') + self.execute_conflict("counter2") def test_sync_conflict_no_retry(self): with self.assertRaises(edgedb.TransactionSerializationError): self.execute_conflict( - 'counter3', - RetryOptions(attempts=1, backoff=edgedb.default_backoff) + "counter3", + RetryOptions(attempts=1, backoff=edgedb.default_backoff), ) - def execute_conflict(self, name='counter2', options=None): + def execute_conflict(self, name="counter2", options=None): con_args = self.get_connect_args().copy() con_args.update(database=self.get_database_name()) client2 = edgedb.create_client(**con_args) @@ -189,7 +206,8 @@ def transaction1(client): barrier.ready() lock.acquire() - res = tx.query_single(''' + res = tx.query_single( + """ SELECT ( INSERT test::Counter { name := $name, @@ -200,7 +218,9 @@ def transaction1(client): SET { value := .value + 1 } ) ).value - ''', name=name) + """, + name=name, + ) lock.release() return res @@ -241,8 +261,9 @@ def test_sync_transaction_interface_errors(self): for tx in self.client.transaction(): tx.start() - with self.assertRaisesRegex(edgedb.InterfaceError, - r'.*Use `with transaction:`'): + with self.assertRaisesRegex( + edgedb.InterfaceError, r".*Use `with transaction:`" + ): for tx in self.client.transaction(): tx.execute("SELECT 123") diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 3ed2fc55..363af9c7 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -26,16 +26,15 @@ class TestSyncTx(tb.SyncQueryTestCase): - - SETUP = ''' + SETUP = """ CREATE TYPE test::TransactionTest EXTENDING std::Object { CREATE PROPERTY name -> std::str; }; - ''' + """ - TEARDOWN = ''' + TEARDOWN = """ DROP TYPE test::TransactionTest; - ''' + """ def test_sync_transaction_regular_01(self): tr = self.client.transaction() @@ -43,20 +42,24 @@ def test_sync_transaction_regular_01(self): with self.assertRaises(ZeroDivisionError): for with_tr in tr: with with_tr: - with_tr.execute(''' + with_tr.execute( + """ INSERT test::TransactionTest { name := 'Test Transaction' }; - ''') + """ + ) 1 / 0 - result = self.client.query(''' + result = self.client.query( + """ SELECT test::TransactionTest FILTER test::TransactionTest.name = 'Test Transaction'; - ''') + """ + ) self.assertEqual(result, []) @@ -82,7 +85,8 @@ async def test_sync_transaction_kinds(self): for tx in client.transaction(): with tx: tx.execute( - 'INSERT test::TransactionTest {name := "test"}') + 'INSERT test::TransactionTest {name := "test"}' + ) except edgedb.TransactionError: self.assertTrue(readonly) else: @@ -109,7 +113,7 @@ def test_sync_transaction_exclusive(self): with self.assertRaisesRegex( edgedb.InterfaceError, "concurrent queries within the same transaction " - "are not allowed" + "are not allowed", ): f1.result(timeout=5) f2.result(timeout=5) diff --git a/tools/gen_init.py b/tools/gen_init.py index bc54aa13..6de9d9c6 100644 --- a/tools/gen_init.py +++ b/tools/gen_init.py @@ -15,46 +15,45 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations import pathlib import re - -if __name__ == '__main__': +if __name__ == "__main__": this = pathlib.Path(__file__) - errors_fn = this.parent.parent / 'edgedb' / 'errors' / '__init__.py' - init_fn = this.parent.parent / 'edgedb' / '__init__.py' + errors_fn = this.parent.parent / "edgedb" / "errors" / "__init__.py" + init_fn = this.parent.parent / "edgedb" / "__init__.py" - with open(errors_fn, 'rt') as f: + with open(errors_fn) as f: errors_txt = f.read() - names = re.findall(r'^class\s+(?P\w+)', errors_txt, re.M) - names_list = '\n'.join(f' {name},' for name in names) - all_list = '\n'.join(f' "{name}",' for name in names) + names = re.findall(r"^class\s+(?P\w+)", errors_txt, re.M) + names_list = "\n".join(f" {name}," for name in names) + all_list = "\n".join(f' "{name}",' for name in names) code = ( - f'''from .errors import (\n{names_list}\n)\n''' - f'''\n__all__.extend([\n{all_list}\n])\n''' + f"""from .errors import (\n{names_list}\n)\n""" + f"""\n__all__.extend([\n{all_list}\n])\n""" ).splitlines() - with open(init_fn, 'rt') as f: + with open(init_fn) as f: lines = f.read().splitlines() start = end = -1 for no, line in enumerate(lines): - if line.startswith('# '): + if line.startswith("# "): start = no - elif line.startswith('# '): + elif line.startswith("# "): end = no if start == -1: - raise RuntimeError('could not find the tag') + raise RuntimeError("could not find the tag") if end == -1: - raise RuntimeError('could not find the tag') + raise RuntimeError("could not find the tag") - lines[start + 1:end] = code + lines[start + 1 : end] = code - with open(init_fn, 'w') as f: - f.write('\n'.join(lines)) - f.write('\n') + with open(init_fn, "w") as f: + f.write("\n".join(lines)) + f.write("\n") From 42790246ad8107b8471a6db3613cbe8ee9ecd572 Mon Sep 17 00:00:00 2001 From: adehad <26027314+adehad@users.noreply.github.com> Date: Sat, 1 Jul 2023 23:21:25 +0100 Subject: [PATCH 4/6] fix codespell errors --- .pre-commit-config.yaml | 2 +- edgedb/credentials.py | 6 +++--- tests/test_async_query.py | 18 +++++++++--------- tests/test_sourcecode.py | 2 +- tests/test_sync_query.py | 18 +++++++++--------- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ae6daca..b02180ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: entry: codespell language: python types: [text] - args: [] + args: [-L, "commitish,"] require_serial: false additional_dependencies: [] - repo: https://github.com/adrienverge/yamllint diff --git a/edgedb/credentials.py b/edgedb/credentials.py index d6ebe40c..5bd1e6c1 100644 --- a/edgedb/credentials.py +++ b/edgedb/credentials.py @@ -98,13 +98,13 @@ def validate_credentials(data: dict) -> Credentials: raise ValueError("`tls_security` must be a string") result["tls_security"] = tls_security - missmatch = ValueError( + mismatch = ValueError( f"tls_verify_hostname={verify} and " f"tls_security={tls_security} are incompatible" ) if tls_security == "strict" and verify is False: - raise missmatch + raise mismatch if tls_security in {"no_host_verification", "insecure"} and verify is True: - raise missmatch + raise mismatch return result diff --git a/tests/test_async_query.py b/tests/test_async_query.py index ceca85bd..8b20d156 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -593,26 +593,26 @@ async def test_async_args_uuid_pack(self): # Test that the custom UUID that our driver uses can be # passed back as a parameter. - ot = await self.client.query_single( + obj2 = await self.client.query_single( "select schema::Object {name} filter .id=$id", id=obj.id ) - self.assertEqual(obj.id, ot.id) - self.assertEqual(obj.name, ot.name) + self.assertEqual(obj.id, obj2.id) + self.assertEqual(obj.name, obj2.name) # Test that a string UUID is acceptable. - ot = await self.client.query_single( + obj2 = await self.client.query_single( "select schema::Object {name} filter .id=$id", id=str(obj.id) ) - self.assertEqual(obj.id, ot.id) - self.assertEqual(obj.name, ot.name) + self.assertEqual(obj.id, obj2.id) + self.assertEqual(obj.name, obj2.name) # Test that a standard uuid.UUID is acceptable. - ot = await self.client.query_single( + obj2 = await self.client.query_single( "select schema::Object {name} filter .id=$id", id=uuid.UUID(bytes=obj.id.bytes), ) - self.assertEqual(obj.id, ot.id) - self.assertEqual(obj.name, ot.name) + self.assertEqual(obj.id, obj2.id) + self.assertEqual(obj.name, obj2.name) with self.assertRaisesRegex( edgedb.InvalidArgumentError, "invalid UUID.*length must be" diff --git a/tests/test_sourcecode.py b/tests/test_sourcecode.py index 12bd0a9c..51a94c1c 100644 --- a/tests/test_sourcecode.py +++ b/tests/test_sourcecode.py @@ -37,7 +37,7 @@ def test_flake8(self): try: import flake8 # NoQA except ImportError: - raise unittest.SkipTest("flake8 moudule is missing") + raise unittest.SkipTest("flake8 module is missing") for subdir in ["edgedb", "tests"]: # ignore any top-level test files try: diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py index b20a16fe..cd5e07e1 100644 --- a/tests/test_sync_query.py +++ b/tests/test_sync_query.py @@ -479,26 +479,26 @@ def test_sync_args_uuid_pack(self): # Test that the custom UUID that our driver uses can be # passed back as a parameter. - ot = self.client.query_single( + obj2 = self.client.query_single( "select schema::Object {name} filter .id=$id", id=obj.id ) - self.assertEqual(obj.id, ot.id) - self.assertEqual(obj.name, ot.name) + self.assertEqual(obj.id, obj2.id) + self.assertEqual(obj.name, obj2.name) # Test that a string UUID is acceptable. - ot = self.client.query_single( + obj2 = self.client.query_single( "select schema::Object {name} filter .id=$id", id=str(obj.id) ) - self.assertEqual(obj.id, ot.id) - self.assertEqual(obj.name, ot.name) + self.assertEqual(obj.id, obj2.id) + self.assertEqual(obj.name, obj2.name) # Test that a standard uuid.UUID is acceptable. - ot = self.client.query_single( + obj2 = self.client.query_single( "select schema::Object {name} filter .id=$id", id=uuid.UUID(bytes=obj.id.bytes), ) - self.assertEqual(obj.id, ot.id) - self.assertEqual(obj.name, ot.name) + self.assertEqual(obj.id, obj2.id) + self.assertEqual(obj.name, obj2.name) with self.assertRaisesRegex( edgedb.InvalidArgumentError, "invalid UUID.*length must be" From 28bdc7eef43ed7d203713e7c917a3f9872f31545 Mon Sep 17 00:00:00 2001 From: adehad <26027314+adehad@users.noreply.github.com> Date: Sat, 1 Jul 2023 23:25:55 +0100 Subject: [PATCH 5/6] fix yamllint errors --- .github/workflows/release.yml | 263 +++++++++++++++++----------------- .github/workflows/tests.yml | 103 ++++++------- .yamllint | 10 ++ 3 files changed, 194 insertions(+), 182 deletions(-) create mode 100644 .yamllint diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3e019fb6..f189412a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,3 +1,4 @@ +--- name: Release on: @@ -15,34 +16,34 @@ jobs: outputs: version: ${{ steps.checkver.outputs.version }} steps: - - name: Validate release PR - uses: edgedb/action-release/validate-pr@master - id: checkver - with: - github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} - require_team: Release Managers - require_approval: no - version_file: edgedb/_version.py - version_line_pattern: | - __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - - - name: Stop if not approved - if: steps.checkver.outputs.approved != 'true' - run: | - echo ::error::PR is not approved yet. - exit 1 - - - name: Store release version for later use - env: - VERSION: ${{ steps.checkver.outputs.version }} - run: | - mkdir -p dist/ - echo "${VERSION}" > dist/VERSION - - - uses: actions/upload-artifact@v3 - with: - name: dist - path: dist/ + - name: Validate release PR + uses: edgedb/action-release/validate-pr@master + id: checkver + with: + github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} + require_team: Release Managers + require_approval: no + version_file: edgedb/_version.py + version_line_pattern: | + __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) + + - name: Stop if not approved + if: steps.checkver.outputs.approved != 'true' + run: | + echo ::error::PR is not approved yet. + exit 1 + + - name: Store release version for later use + env: + VERSION: ${{ steps.checkver.outputs.version }} + run: | + mkdir -p dist/ + echo "${VERSION}" > dist/VERSION + + - uses: actions/upload-artifact@v3 + with: + name: dist + path: dist/ build-sdist: needs: validate-release-request @@ -52,23 +53,23 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 50 - submodules: true + - uses: actions/checkout@v3 + with: + fetch-depth: 50 + submodules: true - - name: Set up Python - uses: actions/setup-python@v2 + - name: Set up Python + uses: actions/setup-python@v2 - - name: Build source distribution - run: | - pip install -U setuptools wheel pip - python setup.py sdist + - name: Build source distribution + run: | + pip install -U setuptools wheel pip + python setup.py sdist - - uses: actions/upload-artifact@v3 - with: - name: dist - path: dist/*.tar.* + - uses: actions/upload-artifact@v3 + with: + name: dist + path: dist/*.tar.* build-wheels-matrix: needs: validate-release-request @@ -109,99 +110,99 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 50 - submodules: true - - - name: Setup WSL - if: ${{ matrix.os == 'windows-2019' }} - uses: vampire/setup-wsl@v2 - with: - wsl-shell-user: edgedb - additional-packages: - ca-certificates - curl - - - name: Set up QEMU - if: runner.os == 'Linux' - uses: docker/setup-qemu-action@v2 - - - name: Install EdgeDB - uses: edgedb/setup-edgedb@v1 - - - uses: pypa/cibuildwheel@v2.12.3 - with: + - uses: actions/checkout@v3 + with: + fetch-depth: 50 + submodules: true + + - name: Setup WSL + if: ${{ matrix.os == 'windows-2019' }} + uses: vampire/setup-wsl@v2 + with: + wsl-shell-user: edgedb + additional-packages: + ca-certificates + curl + + - name: Set up QEMU + if: runner.os == 'Linux' + uses: docker/setup-qemu-action@v2 + + - name: Install EdgeDB + uses: edgedb/setup-edgedb@v1 + + - uses: pypa/cibuildwheel@v2.12.3 + with: only: ${{ matrix.only }} - env: - CIBW_BUILD_VERBOSITY: 1 - CIBW_BEFORE_ALL_LINUX: > - .github/workflows/install-edgedb.sh - CIBW_TEST_EXTRAS: "test" - CIBW_TEST_COMMAND: > - python {project}/tests/__init__.py - CIBW_TEST_COMMAND_WINDOWS: > - python {project}\tests\__init__.py - CIBW_TEST_COMMAND_LINUX: > - PY=`which python` - && CODEGEN=`which edgedb-py` - && chmod -R go+rX "$(dirname $(dirname $(dirname $PY)))" - && su -l edgedb -c "EDGEDB_PYTHON_TEST_CODEGEN_CMD=$CODEGEN $PY {project}/tests/__init__.py" - - - uses: actions/upload-artifact@v3 - with: - name: dist - path: wheelhouse/*.whl + env: + CIBW_BUILD_VERBOSITY: 1 + CIBW_BEFORE_ALL_LINUX: > + .github/workflows/install-edgedb.sh + CIBW_TEST_EXTRAS: "test" + CIBW_TEST_COMMAND: > + python {project}/tests/__init__.py + CIBW_TEST_COMMAND_WINDOWS: > + python {project}\tests\__init__.py + CIBW_TEST_COMMAND_LINUX: > + PY=`which python` + && CODEGEN=`which edgedb-py` + && chmod -R go+rX "$(dirname $(dirname $(dirname $PY)))" + && su -l edgedb -c "EDGEDB_PYTHON_TEST_CODEGEN_CMD=$CODEGEN $PY {project}/tests/__init__.py" + + - uses: actions/upload-artifact@v3 + with: + name: dist + path: wheelhouse/*.whl publish: needs: [build-sdist, build-wheels] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 5 - submodules: false - - - uses: actions/download-artifact@v3 - with: - name: dist - path: dist/ - - - name: Extract Release Version - id: relver - run: | - set -e - echo ::set-output name=version::$(cat dist/VERSION) - rm dist/VERSION - - - name: Merge and tag the PR - uses: edgedb/action-release/merge@master - with: - github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} - ssh_key: ${{ secrets.RELEASE_BOT_SSH_KEY }} - gpg_key: ${{ secrets.RELEASE_BOT_GPG_KEY }} - gpg_key_id: "5C468778062D87BF!" - tag_name: v${{ steps.relver.outputs.version }} - - - name: Publish Github Release - uses: elprans/gh-action-create-release@master - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: v${{ steps.relver.outputs.version }} - release_name: v${{ steps.relver.outputs.version }} - target: ${{ github.event.pull_request.base.ref }} - body: ${{ github.event.pull_request.body }} - draft: true - - - run: | - ls -al dist/ - - - name: Upload to PyPI - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.PYPI_TOKEN }} - # password: ${{ secrets.TEST_PYPI_TOKEN }} - # repository_url: https://test.pypi.org/legacy/ + - uses: actions/checkout@v3 + with: + fetch-depth: 5 + submodules: false + + - uses: actions/download-artifact@v3 + with: + name: dist + path: dist/ + + - name: Extract Release Version + id: relver + run: | + set -e + echo ::set-output name=version::$(cat dist/VERSION) + rm dist/VERSION + + - name: Merge and tag the PR + uses: edgedb/action-release/merge@master + with: + github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} + ssh_key: ${{ secrets.RELEASE_BOT_SSH_KEY }} + gpg_key: ${{ secrets.RELEASE_BOT_GPG_KEY }} + gpg_key_id: "5C468778062D87BF!" + tag_name: v${{ steps.relver.outputs.version }} + + - name: Publish Github Release + uses: elprans/gh-action-create-release@master + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: v${{ steps.relver.outputs.version }} + release_name: v${{ steps.relver.outputs.version }} + target: ${{ github.event.pull_request.base.ref }} + body: ${{ github.event.pull_request.body }} + draft: true + + - run: | + ls -al dist/ + + - name: Upload to PyPI + uses: pypa/gh-action-pypi-publish@master + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} + # password: ${{ secrets.TEST_PYPI_TOKEN }} + # repository_url: https://test.pypi.org/legacy/ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9c50a9ff..c116cf14 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,3 +1,4 @@ +--- name: Tests on: @@ -25,7 +26,7 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - edgedb-version: [stable , nightly] + edgedb-version: [stable, nightly] os: [ubuntu-latest, macos-latest, windows-2019] loop: [asyncio, uvloop] exclude: @@ -34,63 +35,63 @@ jobs: os: windows-2019 steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 50 - submodules: true + - uses: actions/checkout@v3 + with: + fetch-depth: 50 + submodules: true - - name: Check if release PR. - uses: edgedb/action-release/validate-pr@master - id: release - with: - github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} - missing_version_ok: yes - version_file: edgedb/_version.py - version_line_pattern: | - __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) + - name: Check if release PR. + uses: edgedb/action-release/validate-pr@master + id: release + with: + github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} + missing_version_ok: yes + version_file: edgedb/_version.py + version_line_pattern: | + __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - # If this is a release PR, skip tests. They will be run - # as part of the release process, and running them here - # might interfere with the release automation due to - # branch restrictions. + # If this is a release PR, skip tests. They will be run + # as part of the release process, and running them here + # might interfere with the release automation due to + # branch restrictions. - - name: Setup WSL - if: ${{ steps.release.outputs.version == 0 && matrix.os == 'windows-2019' }} - uses: vampire/setup-wsl@v1 - with: - wsl-shell-user: edgedb - additional-packages: - ca-certificates - curl + - name: Setup WSL + if: ${{ steps.release.outputs.version == 0 && matrix.os == 'windows-2019' }} + uses: vampire/setup-wsl@v1 + with: + wsl-shell-user: edgedb + additional-packages: + ca-certificates + curl - - name: Install EdgeDB - uses: edgedb/setup-edgedb@v1 - if: steps.release.outputs.version == 0 - with: - server-version: ${{ matrix.edgedb-version }} + - name: Install EdgeDB + uses: edgedb/setup-edgedb@v1 + if: steps.release.outputs.version == 0 + with: + server-version: ${{ matrix.edgedb-version }} - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - if: steps.release.outputs.version == 0 - with: - python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + if: steps.release.outputs.version == 0 + with: + python-version: ${{ matrix.python-version }} - - name: Install Python Deps - if: steps.release.outputs.version == 0 - run: | - python -m pip install --upgrade setuptools pip wheel - python -m pip install -e .[test] + - name: Install Python Deps + if: steps.release.outputs.version == 0 + run: | + python -m pip install --upgrade setuptools pip wheel + python -m pip install -e .[test] - - name: Test - if: steps.release.outputs.version == 0 - env: - LOOP_IMPL: ${{ matrix.loop }} - run: | - if [ "${LOOP_IMPL}" = "uvloop" ]; then - env USE_UVLOOP=1 python -m unittest -v tests.suite - else - python -m unittest -v tests.suite - fi + - name: Test + if: steps.release.outputs.version == 0 + env: + LOOP_IMPL: ${{ matrix.loop }} + run: | + if [ "${LOOP_IMPL}" = "uvloop" ]; then + env USE_UVLOOP=1 python -m unittest -v tests.suite + else + python -m unittest -v tests.suite + fi # This job exists solely to act as the test job aggregate to be # targeted by branch policies. diff --git a/.yamllint b/.yamllint new file mode 100644 index 00000000..a2819aeb --- /dev/null +++ b/.yamllint @@ -0,0 +1,10 @@ +--- +extends: default + +rules: + line-length: + max: 270 + allow-non-breakable-words: true + allow-non-breakable-inline-mappings: true + new-lines: disable + truthy: disable From 09721235f31b9de431e1b30a7c41f5be198a7367 Mon Sep 17 00:00:00 2001 From: adehad <26027314+adehad@users.noreply.github.com> Date: Sat, 1 Jul 2023 23:39:41 +0100 Subject: [PATCH 6/6] fix E402: Module level import not at top of file --- edgedb/errors/_base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/edgedb/errors/_base.py b/edgedb/errors/_base.py index 6fe67ce2..6423d5c0 100644 --- a/edgedb/errors/_base.py +++ b/edgedb/errors/_base.py @@ -23,6 +23,8 @@ import unicodedata import warnings +from edgedb.color import get_color + __all__ = ( "EdgeDBError", "EdgeDBMessage", @@ -331,6 +333,3 @@ def _unicode_width(text): "EDGEDB_ERROR_HINT can only be one of: default, enabled or disabled" ) SHOW_HINT = False - - -from edgedb.color import get_color