diff --git a/.gitignore b/.gitignore index 7b065ff5fcf3..0c865b80688d 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,13 @@ tests/.coverage* build/ tests/report/ tests/screenshots/ + +.direnv + +*.sqlite3 +passed.tests + +.coverage +.envrc +uv.lock +*.cobp diff --git a/.libcst.codemod.yaml b/.libcst.codemod.yaml new file mode 100644 index 000000000000..0d4a822fddd0 --- /dev/null +++ b/.libcst.codemod.yaml @@ -0,0 +1,17 @@ +# String that LibCST should look for in code which indicates that the +# module is generated code. +generated_code_marker: '@generated' +# Command line and arguments for invoking a code formatter. Anything +# specified here must be capable of taking code via stdin and returning +# formatted code via stdout. +formatter: ['black', '-'] +# List of regex patterns which LibCST will evaluate against filenames to +# determine if the module should be touched. +blacklist_patterns: [] +# List of modules that contain codemods inside of them. +modules: +- 'django.utils.codegen' +# Absolute or relative path of the repository root, used for providing +# full-repo metadata. Relative paths should be specified with this file +# location as the base. +repo_root: '.' diff --git a/django/db/__init__.py b/django/db/__init__.py index aa7d02d0f144..12bc84be165c 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -1,7 +1,12 @@ +from contextlib import contextmanager +import os +from asgiref.local import Local + from django.core import signals from django.db.utils import ( DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, + AsyncConnectionHandler, ConnectionHandler, ConnectionRouter, DatabaseError, @@ -36,6 +41,124 @@ ] connections = ConnectionHandler() +async_connections = AsyncConnectionHandler() + +new_connection_block_depth = Local() +new_connection_block_depth.value = 0 + + +def modify_cxn_depth(f): + try: + existing_value = new_connection_block_depth.value + except AttributeError: + existing_value = 0 + new_connection_block_depth.value = f(existing_value) + + +def should_use_sync_fallback(async_variant): + return async_variant and (new_connection_block_depth.value == 0) + + +commit_allowed = Local() +commit_allowed.value = False + +from contextlib import contextmanager + + +@contextmanager +def set_async_db_commit_permission(perm): + old_value = getattr(commit_allowed, "value", True) + commit_allowed.value = perm + try: + yield + finally: + commit_allowed.value = old_value + + +@contextmanager +def allow_async_db_commits(): + with set_async_db_commit_permission(True): + yield + + +@contextmanager +def block_async_db_commits(): + with set_async_db_commit_permission(False): + yield + + +def is_commit_allowed(): + try: + return commit_allowed.value + except: + # XXX making sure its set + commit_allowed.value = True + return True + + +class new_connection: + """ + Asynchronous context manager to instantiate new async connections. + + """ + + def __init__(self, using=DEFAULT_DB_ALIAS, force_rollback=False): + self.using = using + if not force_rollback and not is_commit_allowed(): + # this is for just figuring everything out + raise ValueError( + "Commits are currently blocked, use allow_async_db_commits to unblock" + ) + self.force_rollback = force_rollback + + def __enter__(self): + # XXX I need to fix up the codegen, for now this is going to no-op + if self.force_rollback: + # XXX IN TEST CONTEXT! + return + else: + raise NotSupportedError("new_connection doesn't support a sync context") + + def __exit__(self, exc_type, exc_value, traceback): + # XXX another thing to remove + return + + async def __aenter__(self): + # XXX stupid nonsense + modify_cxn_depth(lambda v: v + 1) + conn = connections.create_connection(self.using) + if conn.supports_async is False: + raise NotSupportedError( + "The database backend does not support asynchronous execution." + ) + + if conn.in_atomic_block: + raise NotSupportedError( + "Can't open an async connection while inside of a synchronous transaction block" + ) + self.conn = conn + + async_connections.add_connection(self.using, self.conn) + + await self.conn.aensure_connection() + if self.force_rollback is True: + await self.conn.aset_autocommit(False) + + return self.conn + + async def __aexit__(self, exc_type, exc_value, traceback): + # silly nonsense (again) + modify_cxn_depth(lambda v: v - 1) + autocommit = await self.conn.aget_autocommit() + if autocommit is False: + if exc_type is None and self.force_rollback is False: + await self.conn.acommit() + else: + await self.conn.arollback() + await self.conn.aclose() + + async_connections.pop_connection(self.using) + router = ConnectionRouter() diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index e6e0325d07bd..658ad1810f6e 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -2,12 +2,13 @@ import copy import datetime import logging +import os import threading import time import warnings import zoneinfo from collections import deque -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from django.conf import settings from django.core.exceptions import ImproperlyConfigured @@ -20,6 +21,11 @@ from django.db.utils import DatabaseErrorWrapper, ProgrammingError from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property +from django.utils.codegen import ( + from_codegen, + generate_unasynced, + ASYNC_TRUTH_MARKER, +) NO_DB_ALIAS = "__no_db__" RAN_DB_VERSION_CHECK = set() @@ -39,6 +45,8 @@ class BaseDatabaseWrapper: ops = None vendor = "unknown" display_name = "unknown" + supports_async = False + SchemaEditorClass = None # Classes instantiated in __init__(). client_class = None @@ -47,6 +55,7 @@ class BaseDatabaseWrapper: introspection_class = None ops_class = None validation_class = BaseDatabaseValidation + _aconnection_pools = {} queries_limit = 9000 @@ -54,6 +63,7 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): # Connection related attributes. # The underlying database connection. self.connection = None + self.aconnection = None # `settings_dict` should be a dictionary containing keys such as # NAME, USER, etc. It's called `settings_dict` instead of `settings` # to disambiguate it from Django settings modules. @@ -187,22 +197,41 @@ def get_database_version(self): "method." ) - def check_database_version_supported(self): - """ - Raise an error if the database version isn't supported by this - version of Django. - """ + async def aget_database_version(self): + """Return a tuple of the database's version.""" + raise NotSupportedError( + "subclasses of BaseDatabaseWrapper may require an aget_database_version() " + "method." + ) + + def _validate_database_version_supported(self, db_version): if ( self.features.minimum_database_version is not None - and self.get_database_version() < self.features.minimum_database_version + and db_version < self.features.minimum_database_version ): - db_version = ".".join(map(str, self.get_database_version())) + str_db_version = ".".join(map(str, db_version)) min_db_version = ".".join(map(str, self.features.minimum_database_version)) raise NotSupportedError( f"{self.display_name} {min_db_version} or later is required " - f"(found {db_version})." + f"(found {str_db_version})." ) + def check_database_version_supported(self): + """ + Raise an error if the database version isn't supported by this + version of Django. + """ + db_version = self.get_database_version() + self._validate_database_version_supported(db_version) + + async def acheck_database_version_supported(self): + """ + Raise an error if the database version isn't supported by this + version of Django. + """ + db_version = await self.aget_database_version() + self._validate_database_version_supported(db_version) + # ##### Backend-specific methods for creating connections and cursors ##### def get_connection_params(self): @@ -219,6 +248,14 @@ def get_new_connection(self, conn_params): "method" ) + async def aget_new_connection(self, conn_params): + """Open a connection to the database.""" + raise NotSupportedError( + "subclasses of BaseDatabaseWrapper may require an aget_new_connection() " + "method" + ) + + @from_codegen def init_connection_state(self): """Initialize the database connection settings.""" global RAN_DB_VERSION_CHECK @@ -226,18 +263,30 @@ def init_connection_state(self): self.check_database_version_supported() RAN_DB_VERSION_CHECK.add(self.alias) + @generate_unasynced() + async def ainit_connection_state(self): + """Initialize the database connection settings.""" + global RAN_DB_VERSION_CHECK + if self.alias not in RAN_DB_VERSION_CHECK: + await self.acheck_database_version_supported() + RAN_DB_VERSION_CHECK.add(self.alias) + def create_cursor(self, name=None): """Create a cursor. Assume that a connection is established.""" raise NotImplementedError( "subclasses of BaseDatabaseWrapper may require a create_cursor() method" ) + def create_async_cursor(self, name=None): + """Create a cursor. Assume that a connection is established.""" + raise NotSupportedError( + "subclasses of BaseDatabaseWrapper may require a " + "create_async_cursor() method" + ) + # ##### Backend-specific methods for creating connections ##### - @async_unsafe - def connect(self): - """Connect to the database. Assume that the connection is closed.""" - # Check for invalid configurations. + def _pre_connect(self): self.check_settings() # In case the previous connection was closed while in an atomic block self.in_atomic_block = False @@ -252,6 +301,13 @@ def connect(self): self.errors_occurred = False # New connections are healthy. self.health_check_done = True + + @from_codegen + @async_unsafe + def connect(self): + """Connect to the database. Assume that the connection is closed.""" + # Check for invalid configurations. + self._pre_connect() # Establish the connection conn_params = self.get_connection_params() self.connection = self.get_new_connection(conn_params) @@ -261,6 +317,24 @@ def connect(self): self.run_on_commit = [] + @generate_unasynced(async_unsafe=True) + async def aconnect(self): + """Connect to the database. Assume that the connection is closed.""" + # Check for invalid configurations. + self._pre_connect() + if ASYNC_TRUTH_MARKER: + # Establish the connection + conn_params = self.get_connection_params(for_async=True) + else: + # Establish the connection + conn_params = self.get_connection_params() + self.aconnection = await self.aget_new_connection(conn_params) + await self.aset_autocommit(self.settings_dict["AUTOCOMMIT"]) + await self.ainit_connection_state() + await connection_created.asend(sender=self.__class__, connection=self) + + self.run_on_commit = [] + def check_settings(self): if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ: raise ImproperlyConfigured( @@ -268,6 +342,7 @@ def check_settings(self): % self.alias ) + @from_codegen @async_unsafe def ensure_connection(self): """Guarantee that a connection to the database is established.""" @@ -279,6 +354,17 @@ def ensure_connection(self): with self.wrap_database_errors: self.connect() + @generate_unasynced(async_unsafe=True) + async def aensure_connection(self): + """Guarantee that a connection to the database is established.""" + if self.aconnection is None: + if self.in_atomic_block and self.closed_in_transaction: + raise ProgrammingError( + "Cannot open a new connection in an atomic block." + ) + with self.wrap_database_errors: + await self.aconnect() + # ##### Backend-specific wrappers for PEP-249 connection methods ##### def _prepare_cursor(self, cursor): @@ -292,27 +378,65 @@ def _prepare_cursor(self, cursor): wrapped_cursor = self.make_cursor(cursor) return wrapped_cursor + def _aprepare_cursor(self, cursor) -> utils.AsyncCursorWrapper: + """ + Validate the connection is usable and perform database cursor wrapping. + """ + + self.validate_thread_sharing() + if self.queries_logged: + wrapped_cursor = self.make_debug_async_cursor(cursor) + else: + wrapped_cursor = self.make_async_cursor(cursor) + return wrapped_cursor + def _cursor(self, name=None): self.close_if_health_check_failed() self.ensure_connection() with self.wrap_database_errors: return self._prepare_cursor(self.create_cursor(name)) + def _acursor(self, name=None) -> utils.AsyncCursorCtx: + return utils.AsyncCursorCtx(self, name) + + @from_codegen def _commit(self): if self.connection is not None: with debug_transaction(self, "COMMIT"), self.wrap_database_errors: return self.connection.commit() + @generate_unasynced() + async def _acommit(self): + if self.aconnection is not None: + with debug_transaction(self, "COMMIT"), self.wrap_database_errors: + return await self.aconnection.commit() + + @from_codegen def _rollback(self): if self.connection is not None: with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors: return self.connection.rollback() + @generate_unasynced() + async def _arollback(self): + if self.aconnection is not None: + with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors: + return await self.aconnection.rollback() + + @from_codegen def _close(self): + print(f"YYY {id(self)} BDW CLOSE") if self.connection is not None: with self.wrap_database_errors: return self.connection.close() + @generate_unasynced() + async def _aclose(self): + print(f"YYY {id(self)} BDW CLOSE") + if self.aconnection is not None: + with self.wrap_database_errors: + return await self.aconnection.close() + # ##### Generic wrappers for PEP-249 connection methods ##### @async_unsafe @@ -320,6 +444,13 @@ def cursor(self): """Create a cursor, opening a connection if necessary.""" return self._cursor() + def acursor(self) -> utils.AsyncCursorCtx: + """Create an async cursor, opening a connection if necessary.""" + # if ASYNC_TRUTH_MARKER: + # self.validate_no_atomic_block() + return self._acursor() + + @from_codegen @async_unsafe def commit(self): """Commit a transaction and reset the dirty flag.""" @@ -330,6 +461,17 @@ def commit(self): self.errors_occurred = False self.run_commit_hooks_on_set_autocommit_on = True + @generate_unasynced(async_unsafe=True) + async def acommit(self): + """Commit a transaction and reset the dirty flag.""" + self.validate_thread_sharing() + self.validate_no_atomic_block() + await self._acommit() + # A successful commit means that the database connection works. + self.errors_occurred = False + self.run_commit_hooks_on_set_autocommit_on = True + + @from_codegen @async_unsafe def rollback(self): """Roll back a transaction and reset the dirty flag.""" @@ -341,6 +483,18 @@ def rollback(self): self.needs_rollback = False self.run_on_commit = [] + @generate_unasynced(async_unsafe=True) + async def arollback(self): + """Roll back a transaction and reset the dirty flag.""" + self.validate_thread_sharing() + self.validate_no_atomic_block() + await self._arollback() + # A successful rollback means that the database connection works. + self.errors_occurred = False + self.needs_rollback = False + self.run_on_commit = [] + + @from_codegen @async_unsafe def close(self): """Close the connection to the database.""" @@ -361,24 +515,60 @@ def close(self): else: self.connection = None + @generate_unasynced(async_unsafe=True) + async def aclose(self): + """Close the connection to the database.""" + self.validate_thread_sharing() + self.run_on_commit = [] + + # Don't call validate_no_atomic_block() to avoid making it difficult + # to get rid of a connection in an invalid state. The next connect() + # will reset the transaction state anyway. + if self.closed_in_transaction or self.aconnection is None: + return + try: + await self._aclose() + finally: + if self.in_atomic_block: + self.closed_in_transaction = True + self.needs_rollback = True + else: + self.aconnection = None + # ##### Backend-specific savepoint management methods ##### def _savepoint(self, sid): with self.cursor() as cursor: cursor.execute(self.ops.savepoint_create_sql(sid)) + async def _asavepoint(self, sid): + async with self.acursor() as cursor: + await cursor.aexecute(self.ops.savepoint_create_sql(sid)) + def _savepoint_rollback(self, sid): with self.cursor() as cursor: cursor.execute(self.ops.savepoint_rollback_sql(sid)) + async def _asavepoint_rollback(self, sid): + async with self.acursor() as cursor: + await cursor.aexecute(self.ops.savepoint_rollback_sql(sid)) + def _savepoint_commit(self, sid): with self.cursor() as cursor: cursor.execute(self.ops.savepoint_commit_sql(sid)) + async def _asavepoint_commit(self, sid): + async with self.acursor() as cursor: + await cursor.aexecute(self.ops.savepoint_commit_sql(sid)) + def _savepoint_allowed(self): # Savepoints cannot be created outside a transaction return self.features.uses_savepoints and not self.get_autocommit() + async def _asavepoint_allowed(self): + # Savepoints cannot be created outside a transaction + return self.features.uses_savepoints and not (await self.aget_autocommit()) + # ##### Generic savepoint management methods ##### @async_unsafe @@ -402,6 +592,26 @@ def savepoint(self): return sid + async def asavepoint(self): + """ + Create a savepoint inside the current transaction. Return an + identifier for the savepoint that will be used for the subsequent + rollback or commit. Do nothing if savepoints are not supported. + """ + if not (await self._asavepoint_allowed()): + return + + thread_ident = _thread.get_ident() + tid = str(thread_ident).replace("-", "") + + self.savepoint_state += 1 + sid = "s%s_x%d" % (tid, self.savepoint_state) + + self.validate_thread_sharing() + await self._asavepoint(sid) + + return sid + @async_unsafe def savepoint_rollback(self, sid): """ @@ -420,6 +630,23 @@ def savepoint_rollback(self, sid): if sid not in sids ] + async def asavepoint_rollback(self, sid): + """ + Roll back to a savepoint. Do nothing if savepoints are not supported. + """ + if not (await self._asavepoint_allowed()): + return + + self.validate_thread_sharing() + await self._asavepoint_rollback(sid) + + # Remove any callbacks registered while this savepoint was active. + self.run_on_commit = [ + (sids, func, robust) + for (sids, func, robust) in self.run_on_commit + if sid not in sids + ] + @async_unsafe def savepoint_commit(self, sid): """ @@ -431,6 +658,16 @@ def savepoint_commit(self, sid): self.validate_thread_sharing() self._savepoint_commit(sid) + async def asavepoint_commit(self, sid): + """ + Release a savepoint. Do nothing if savepoints are not supported. + """ + if not (await self._asavepoint_allowed()): + return + + self.validate_thread_sharing() + await self._asavepoint_commit(sid) + @async_unsafe def clean_savepoints(self): """ @@ -448,11 +685,26 @@ def _set_autocommit(self, autocommit): "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method" ) + async def _aset_autocommit(self, autocommit): + """ + Backend-specific implementation to enable or disable autocommit. + """ + raise NotSupportedError( + "subclasses of BaseDatabaseWrapper may require an _aset_autocommit() method" + ) + # ##### Generic transaction management methods ##### def get_autocommit(self): """Get the autocommit state.""" self.ensure_connection() + # print(f"get_autocommit() <- {self.autocommit}") + return self.autocommit + + async def aget_autocommit(self): + """Get the autocommit state.""" + await self.aensure_connection() + # print(f"aget_autocommit() <- {self.autocommit}") return self.autocommit def set_autocommit( @@ -492,6 +744,46 @@ def set_autocommit( self.run_and_clear_commit_hooks() self.run_commit_hooks_on_set_autocommit_on = False + async def aset_autocommit( + self, autocommit, force_begin_transaction_with_broken_autocommit=False + ): + """ + Enable or disable autocommit. + + The usual way to start a transaction is to turn autocommit off. + SQLite does not properly start a transaction when disabling + autocommit. To avoid this buggy behavior and to actually enter a new + transaction, an explicit BEGIN is required. Using + force_begin_transaction_with_broken_autocommit=True will issue an + explicit BEGIN with SQLite. This option will be ignored for other + backends. + """ + # print(f"{id(self)}.aset_autocommit({autocommit})") + # if autocommit is False: + # raise ValueError("FALSE") + self.validate_no_atomic_block() + await self.aclose_if_health_check_failed() + await self.aensure_connection() + + start_transaction_under_autocommit = ( + force_begin_transaction_with_broken_autocommit + and not autocommit + and hasattr(self, "_astart_transaction_under_autocommit") + ) + + if start_transaction_under_autocommit: + await self._astart_transaction_under_autocommit() + elif autocommit: + await self._aset_autocommit(autocommit) + else: + with debug_transaction(self, "BEGIN"): + await self._aset_autocommit(autocommit) + self.autocommit = autocommit + + if autocommit and self.run_commit_hooks_on_set_autocommit_on: + self.run_and_clear_commit_hooks() + self.run_commit_hooks_on_set_autocommit_on = False + def get_rollback(self): """Get the "needs rollback" flag -- for *advanced use* only.""" if not self.in_atomic_block: @@ -589,6 +881,20 @@ def close_if_health_check_failed(self): self.close() self.health_check_done = True + async def aclose_if_health_check_failed(self): + """Close existing connection if it fails a health check.""" + if ( + self.aconnection is None + or not self.health_check_enabled + or self.health_check_done + ): + return + + is_usable = await self.ais_usable() + if not is_usable: + await self.aclose() + self.health_check_done = True + def close_if_unusable_or_obsolete(self): """ Close the current connection if unrecoverable errors have occurred @@ -678,10 +984,18 @@ def make_debug_cursor(self, cursor): """Create a cursor that logs all queries in self.queries_log.""" return utils.CursorDebugWrapper(cursor, self) + def make_debug_async_cursor(self, cursor): + """Create a cursor that logs all queries in self.queries_log.""" + return utils.AsyncCursorDebugWrapper(cursor, self) + def make_cursor(self, cursor): """Create a cursor without debug logging.""" return utils.CursorWrapper(cursor, self) + def make_async_cursor(self, cursor): + """Create a cursor without debug logging.""" + return utils.AsyncCursorWrapper(cursor, self) + @contextmanager def temporary_connection(self): """ @@ -699,6 +1013,25 @@ def temporary_connection(self): if must_close: self.close() + @asynccontextmanager + async def atemporary_connection(self): + """ + Context manager that ensures that a connection is established, and + if it opened one, closes it to avoid leaving a dangling connection. + This is useful for operations outside of the request-response cycle. + + Provide a cursor: async with self.atemporary_connection() as cursor: ... + """ + # unused + + must_close = self.aconnection is None + try: + async with self.acursor() as cursor: + yield cursor + finally: + if must_close: + await self.aclose() + @contextmanager def _nodb_cursor(self): """ diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 5d1f260edfc7..65dd45e477d6 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -9,6 +9,7 @@ from django.db import NotSupportedError, transaction from django.db.models.expressions import Col from django.utils import timezone +from django.utils.codegen import from_codegen, generate_unasynced from django.utils.encoding import force_str @@ -205,6 +206,7 @@ def distinct_sql(self, fields, params): else: return ["DISTINCT"], [] + @from_codegen def fetch_returned_insert_columns(self, cursor, returning_params): """ Given a cursor object that has just performed an INSERT...RETURNING @@ -212,6 +214,14 @@ def fetch_returned_insert_columns(self, cursor, returning_params): """ return cursor.fetchone() + @generate_unasynced() + async def afetch_returned_insert_columns(self, cursor, returning_params): + """ + Given a cursor object that has just performed an INSERT...RETURNING + statement into a table, return the newly created data. + """ + return await cursor.afetchone() + def force_group_by(self): """ Return a GROUP BY clause to use with a HAVING clause when no grouping diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index c864cab57a2e..6c9ba2e593bb 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -5,7 +5,10 @@ """ import asyncio +import inspect +import os import threading +import traceback import warnings from contextlib import contextmanager @@ -14,12 +17,17 @@ from django.db import DatabaseError as WrappedDatabaseError from django.db import connections from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper +from django.db.backends.utils import ( + AsyncCursorDebugWrapper as AsyncBaseCursorDebugWrapper, +) from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property from django.utils.safestring import SafeString from django.utils.version import get_version_tuple +LOG_CREATIONS = False + try: try: import psycopg as Database @@ -86,9 +94,16 @@ def _get_varchar_column(data): return "varchar(%(max_length)s)" % data +# HACK additions to make OTel instrumentation work properly +Database.AsyncConnection.pq = Database.pq +Database.Connection.pq = Database.pq + + class DatabaseWrapper(BaseDatabaseWrapper): vendor = "postgresql" display_name = "PostgreSQL" + supports_async = is_psycopg3 + # This dictionary maps Field objects to their associated PostgreSQL column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. @@ -181,6 +196,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): _named_cursor_idx = 0 _connection_pools = {} + def __init__(self, *args, **kwargs): + self._creation_stack = "\n".join(traceback.format_stack()) + super().__init__(*args, **kwargs) + @property def pool(self): pool_options = self.settings_dict["OPTIONS"].get("pool") @@ -222,11 +241,57 @@ def pool(self): return self._connection_pools[self.alias] + @property + def apool(self): + pool_options = self.settings_dict["OPTIONS"].get("pool") + if self.alias == NO_DB_ALIAS or not pool_options: + return None + + if self.alias not in self._aconnection_pools: + if self.settings_dict.get("CONN_MAX_AGE", 0) != 0: + raise ImproperlyConfigured( + "Pooling doesn't support persistent connections." + ) + # Set the default options. + if pool_options is True: + pool_options = {} + + try: + from psycopg_pool import AsyncConnectionPool + except ImportError as err: + raise ImproperlyConfigured( + "Error loading psycopg_pool module.\nDid you install psycopg[pool]?" + ) from err + + connect_kwargs = self.get_connection_params(for_async=True) + # Ensure we run in autocommit, Django properly sets it later on. + connect_kwargs["autocommit"] = True + enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"] + pool = AsyncConnectionPool( + kwargs=connect_kwargs, + open=False, # Do not open the pool during startup. + configure=self._aconfigure_connection, + check=AsyncConnectionPool.check_connection if enable_checks else None, + **pool_options, + ) + # setdefault() ensures that multiple threads don't set this in + # parallel. Since we do not open the pool during it's init above, + # this means that at worst during startup multiple threads generate + # pool objects and the first to set it wins. + self._aconnection_pools.setdefault(self.alias, pool) + + return self._aconnection_pools[self.alias] + def close_pool(self): if self.pool: self.pool.close() del self._connection_pools[self.alias] + async def aclose_pool(self): + if self.apool: + await self.apool.close() + del self._aconnection_pools[self.alias] + def get_database_version(self): """ Return a tuple of the database's version. @@ -234,7 +299,38 @@ def get_database_version(self): """ return divmod(self.pg_version, 10000) - def get_connection_params(self): + async def aget_database_version(self): + """ + Return a tuple of the database's version. + E.g. for pg_version 120004, return (12, 4). + """ + pg_version = await self.apg_version + return divmod(pg_version, 10000) + + def _get_sync_cursor_factory(self, server_side_binding=None): + if is_psycopg3 and server_side_binding is True: + return ServerBindingCursor + else: + return Cursor + + def _get_async_cursor_factory(self, server_side_binding=None): + if is_psycopg3 and server_side_binding is True: + return AsyncServerBindingCursor + else: + return AsyncCursor + + def _get_cursor_factory(self, server_side_binding=None, for_async=False): + if for_async and not is_psycopg3: + raise ImproperlyConfigured( + "Django requires psycopg >= 3 for ORM async support." + ) + + if for_async: + return self._get_async_cursor_factory(server_side_binding) + else: + return self._get_sync_cursor_factory(server_side_binding) + + def get_connection_params(self, for_async=False): settings_dict = self.settings_dict # None may be used to connect to the default 'postgres' db if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"): @@ -274,14 +370,10 @@ def get_connection_params(self): raise ImproperlyConfigured("Database pooling requires psycopg >= 3") server_side_binding = conn_params.pop("server_side_binding", None) - conn_params.setdefault( - "cursor_factory", - ( - ServerBindingCursor - if is_psycopg3 and server_side_binding is True - else Cursor - ), + cursor_factory = self._get_cursor_factory( + server_side_binding, for_async=for_async ) + conn_params.setdefault("cursor_factory", cursor_factory) if settings_dict["USER"]: conn_params["user"] = settings_dict["USER"] if settings_dict["PASSWORD"]: @@ -301,8 +393,7 @@ def get_connection_params(self): ) return conn_params - @async_unsafe - def get_new_connection(self, conn_params): + def _get_isolation_level(self): # self.isolation_level must be set: # - after connecting to the database in order to obtain the database's # default when no value is explicitly specified in options. @@ -313,25 +404,30 @@ def get_new_connection(self, conn_params): try: isolation_level_value = options["isolation_level"] except KeyError: - self.isolation_level = IsolationLevel.READ_COMMITTED + isolation_level = IsolationLevel.READ_COMMITTED else: - # Set the isolation level to the value from OPTIONS. try: - self.isolation_level = IsolationLevel(isolation_level_value) + isolation_level = IsolationLevel(isolation_level_value) set_isolation_level = True except ValueError: raise ImproperlyConfigured( f"Invalid transaction isolation level {isolation_level_value} " f"specified. Use one of the psycopg.IsolationLevel values." ) + return isolation_level, set_isolation_level + + @async_unsafe + def get_new_connection(self, conn_params): + isolation_level, set_isolation_level = self._get_isolation_level() + self.isolation_level = isolation_level if self.pool: # If nothing else has opened the pool, open it now. self.pool.open() connection = self.pool.getconn() else: - connection = self.Database.connect(**conn_params) + connection = Database.Connection.connect(**conn_params) if set_isolation_level: - connection.isolation_level = self.isolation_level + connection.isolation_level = isolation_level if not is_psycopg3: # Register dummy loads() to avoid a round trip from psycopg2's # decode to json.dumps() to json.loads(), when using a custom @@ -341,6 +437,19 @@ def get_new_connection(self, conn_params): ) return connection + async def aget_new_connection(self, conn_params): + isolation_level, set_isolation_level = self._get_isolation_level() + self.isolation_level = isolation_level + if self.apool: + # If nothing else has opened the pool, open it now. + await self.apool.open() + connection = await self.apool.getconn() + else: + connection = await self.Database.AsyncConnection.connect(**conn_params) + if set_isolation_level: + connection.isolation_level = isolation_level + return connection + def ensure_timezone(self): # Close the pool so new connections pick up the correct timezone. self.close_pool() @@ -348,6 +457,13 @@ def ensure_timezone(self): return False return self._configure_timezone(self.connection) + async def aensure_timezone(self): + # Close the pool so new connections pick up the correct timezone. + await self.aclose_pool() + if self.connection is None: + return False + return await self._aconfigure_timezone(self.connection) + def _configure_timezone(self, connection): conn_timezone_name = connection.info.parameter_status("TimeZone") timezone_name = self.timezone_name @@ -357,6 +473,15 @@ def _configure_timezone(self, connection): return True return False + async def _aconfigure_timezone(self, connection): + conn_timezone_name = connection.info.parameter_status("TimeZone") + timezone_name = self.timezone_name + if timezone_name and conn_timezone_name != timezone_name: + async with connection.cursor() as cursor: + await cursor.execute(self.ops.set_time_zone_sql(), [timezone_name]) + return True + return False + def _configure_role(self, connection): if new_role := self.settings_dict["OPTIONS"].get("assume_role"): with connection.cursor() as cursor: @@ -365,6 +490,14 @@ def _configure_role(self, connection): return True return False + async def _aconfigure_role(self, connection): + if new_role := self.settings_dict["OPTIONS"].get("assume_role"): + async with connection.acursor() as cursor: + sql = self.ops.compose_sql("SET ROLE %s", [new_role]) + await cursor.aaexecute(sql) + return True + return False + def _configure_connection(self, connection): # This function is called from init_connection_state and from the # psycopg pool itself after a connection is opened. @@ -378,7 +511,22 @@ def _configure_connection(self, connection): return commit_role or commit_tz + async def _aconfigure_connection(self, connection): + # This function is called from init_connection_state and from the + # psycopg pool itself after a connection is opened. + + # Commit after setting the time zone. + commit_tz = await self._aconfigure_timezone(connection) + # Set the role on the connection. This is useful if the credential used + # to login is not the same as the role that owns database resources. As + # can be the case when using temporary or ephemeral credentials. + commit_role = await self._aconfigure_role(connection) + + return commit_role or commit_tz + def _close(self): + if "QL" in os.environ: + print(f"QQQ {id(self)} BDW CLOSE") if self.connection is not None: # `wrap_database_errors` only works for `putconn` as long as there # is no `reset` function set in the pool because it is deferred @@ -394,6 +542,24 @@ def _close(self): else: return self.connection.close() + async def _aclose(self): + if "QL" in os.environ: + print(f"QQQ {id(self)} BDW CLOSE") + if self.aconnection is not None: + # `wrap_database_errors` only works for `putconn` as long as there + # is no `reset` function set in the pool because it is deferred + # into a thread and not directly executed. + with self.wrap_database_errors: + if self.apool: + # Ensure the correct pool is returned. This is a workaround + # for tests so a pool can be changed on setting changes + # (e.g. USE_TZ, TIME_ZONE). + await self.aconnection._pool.putconn(self.aconnection) + # Connection can no longer be used. + self.aconnection = None + else: + return await self.aconnection.close() + def init_connection_state(self): super().init_connection_state() @@ -403,6 +569,17 @@ def init_connection_state(self): if commit and not self.get_autocommit(): self.connection.commit() + async def ainit_connection_state(self): + await super().ainit_connection_state() + + if self.aconnection is not None and not self.apool: + commit = await self._aconfigure_connection(self.aconnection) + + if commit: + autocommit = await self.aget_autocommit() + if not autocommit: + await self.aconnection.commit() + @async_unsafe def create_cursor(self, name=None): if name: @@ -438,6 +615,35 @@ def create_cursor(self, name=None): cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None return cursor + def create_async_cursor(self, name=None): + if name: + if self.settings_dict["OPTIONS"].get("server_side_binding") is not True: + # psycopg >= 3 forces the usage of server-side bindings for + # named cursors so a specialized class that implements + # server-side cursors while performing client-side bindings + # must be used if `server_side_binding` is disabled (default). + cursor = AsyncServerSideCursor( + self.aconnection, + name=name, + scrollable=False, + withhold=self.aconnection.autocommit, + ) + else: + # In autocommit mode, the cursor will be used outside of a + # transaction, hence use a holdable cursor. + cursor = self.aconnection.cursor( + name, scrollable=False, withhold=self.aconnection.autocommit + ) + else: + cursor = self.aconnection.cursor() + + # Register the cursor timezone only if the connection disagrees, to + # avoid copying the adapter map. + tzloader = self.aconnection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT) + if self.timezone != tzloader.timezone: + register_tzloader(self.timezone, cursor) + return cursor + def tzinfo_factory(self, offset): return self.timezone @@ -469,10 +675,37 @@ def chunked_cursor(self): ) ) + async def achunked_cursor(self): + self._named_cursor_idx += 1 + # Get the current async task + try: + current_task = asyncio.current_task() + except RuntimeError: + current_task = None + # Current task can be none even if the current_task call didn't error + if current_task: + task_ident = str(id(current_task)) + else: + task_ident = "sync" + # Use that and the thread ident to get a unique name + return self._acursor( + name="_django_curs_%d_%s_%d" + % ( + # Avoid reusing name in other threads / tasks + threading.current_thread().ident, + task_ident, + self._named_cursor_idx, + ) + ) + def _set_autocommit(self, autocommit): with self.wrap_database_errors: self.connection.autocommit = autocommit + async def _aset_autocommit(self, autocommit): + with self.wrap_database_errors: + await self.aconnection.set_autocommit(autocommit) + def check_constraints(self, table_names=None): """ Check constraints by setting them to immediate. Return them to deferred @@ -500,6 +733,12 @@ def close_if_health_check_failed(self): return return super().close_if_health_check_failed() + async def aclose_if_health_check_failed(self): + if self.apool: + # The pool only returns healthy connections. + return + return await super().aclose_if_health_check_failed() + @contextmanager def _nodb_cursor(self): cursor = None @@ -543,6 +782,11 @@ def pg_version(self): with self.temporary_connection(): return self.connection.info.server_version + @cached_property + async def apg_version(self): + async with self.atemporary_connection(): + return self.aconnection.info.server_version + def make_debug_cursor(self, cursor): return CursorDebugWrapper(cursor, self) @@ -598,6 +842,36 @@ def copy(self, statement): with self.debug_sql(statement): return self.cursor.copy(statement) + class AsyncServerBindingCursor(CursorMixin, Database.AsyncClientCursor): + pass + + class AsyncCursor(CursorMixin, Database.AsyncClientCursor): + pass + + class AsyncServerSideCursor( + CursorMixin, + Database.client_cursor.ClientCursorMixin, + Database.AsyncServerCursor, + ): + """ + psycopg >= 3 forces the usage of server-side bindings when using named + cursors but the ORM doesn't yet support the systematic generation of + prepareable SQL (#20516). + + ClientCursorMixin forces the usage of client-side bindings while + AsyncServerCursor implements the logic required to declare and scroll + through named cursors. + + Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to + specify how parameters should be bound instead, which AsyncServerCursor + would inherit, but that's not the case. + """ + + class AsyncCursorDebugWrapper(AsyncBaseCursorDebugWrapper): + def copy(self, statement): + with self.debug_sql(statement): + return self.cursor.copy(statement) + else: Cursor = psycopg2.extensions.cursor diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 9db755bb8919..67a3631473be 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -14,6 +14,7 @@ from django.db.backends.utils import split_tzname_delta from django.db.models.constants import OnConflict from django.db.models.functions import Cast +from django.utils.codegen import from_codegen, generate_unasynced from django.utils.regex_helper import _lazy_re_compile @@ -155,6 +156,7 @@ def bulk_insert_sql(self, fields, placeholder_rows): return f"SELECT * FROM {placeholder_rows}" return super().bulk_insert_sql(fields, placeholder_rows) + @from_codegen def fetch_returned_insert_rows(self, cursor): """ Given a cursor object that has just performed an INSERT...RETURNING @@ -162,6 +164,14 @@ def fetch_returned_insert_rows(self, cursor): """ return cursor.fetchall() + @generate_unasynced() + async def afetch_returned_insert_rows(self, cursor): + """ + Given a cursor object that has just performed an INSERT...RETURNING + statement into a table, return the tuple of returned data. + """ + return await cursor.fetchall() + def lookup_cast(self, lookup_type, internal_type=None): lookup = "%s" # Cast text lookups to text to allow things like filter(x__contains=4) diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 568f510a670e..76b0c6d5d9a2 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -11,8 +11,55 @@ from django.db import NotSupportedError from django.utils.dateparse import parse_time +from asgiref.local import Local + logger = logging.getLogger("django.db.backends") +# XXX experimentation +sync_cursor_ops_local = Local() +sync_cursor_ops_local.value = False + + +# XXX experimentation +class sync_cursor_ops_blocked: + @classmethod + def get(cls): + # This is extremely wrong! Maybe. To think about + try: + return sync_cursor_ops_local.value + except AttributeError: + # if it's not set... it's not True + sync_cursor_ops_local.value = False + return False + + @classmethod + def set(cls, v): + sync_cursor_ops_local.value = v + + +# XXX experimentation +@contextmanager +def block_sync_ops(): + old_val = sync_cursor_ops_blocked.get() + sync_cursor_ops_blocked.set(True) + try: + print("Started blocking sync ops.") + yield + finally: + sync_cursor_ops_blocked.set(old_val) + print("Stopped blocking sync ops.") + + +# XXX experimentation +@contextmanager +def unblock_sync_ops(): + old_val = sync_cursor_ops_blocked.get() + sync_cursor_ops_blocked.set(False) + try: + yield + finally: + sync_cursor_ops_blocked.set(old_val) + class CursorWrapper: def __init__(self, cursor, db): @@ -21,6 +68,10 @@ def __init__(self, cursor, db): WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"]) + # XXX experimentation + SYNC_BLOCK = {"close"} + # XXX experimentation + SAFE_LIST = set() APPS_NOT_READY_WARNING_MSG = ( "Accessing the database during app initialization is discouraged. To fix this " "warning, avoid executing queries in AppConfig.ready() or when your app " @@ -28,6 +79,18 @@ def __init__(self, cursor, db): ) def __getattr__(self, attr): + # XXX experimentation + # (the point here is being able to focus on a chunk of code in a specific + # way to identify if something is unintentionally falling back to sync ops) + if sync_cursor_ops_blocked.get(): + if attr in CursorWrapper.WRAP_ERROR_ATTRS: + raise ValueError("Sync operations blocked!") + elif attr in CursorWrapper.SYNC_BLOCK: + raise ValueError("Sync operations blocked!") + elif attr in CursorWrapper.SAFE_LIST: + pass + else: + print(f"CursorWrapper.{attr} accessed") cursor_attr = getattr(self.cursor, attr) if attr in CursorWrapper.WRAP_ERROR_ATTRS: return self.db.wrap_database_errors(cursor_attr) @@ -114,6 +177,97 @@ def _executemany(self, sql, param_list, *ignored_wrapper_args): return self.cursor.executemany(sql, param_list) +class AsyncCursorCtx: + """ + Asynchronous context manager to hold an async cursor. + """ + + def __init__(self, db, name=None): + self.db = db + self.name = name + self.wrap_database_errors = self.db.wrap_database_errors + + async def __aenter__(self) -> "AsyncCursorWrapper": + await self.db.aclose_if_health_check_failed() + await self.db.aensure_connection() + self.wrap_database_errors.__enter__() + return self.db._aprepare_cursor(self.db.create_async_cursor(self.name)) + + async def __aexit__(self, type, value, traceback): + self.wrap_database_errors.__exit__(type, value, traceback) + + +class AsyncCursorWrapper(CursorWrapper): + async def _aexecute(self, sql, params, *ignored_wrapper_args): + # Raise a warning during app initialization (stored_app_configs is only + # ever set during testing). + if not apps.ready and not apps.stored_app_configs: + warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning) + self.db.validate_no_broken_transaction() + with self.db.wrap_database_errors: + if params is None: + # params default might be backend specific. + return await self.cursor.execute(sql) + else: + return await self.cursor.execute(sql, params) + + async def _aexecute_with_wrappers(self, sql, params, many, executor): + context = {"connection": self.db, "cursor": self} + for wrapper in reversed(self.db.execute_wrappers): + executor = functools.partial(wrapper, executor) + return await executor(sql, params, many, context) + + async def aexecute(self, sql, params=None): + return await self._aexecute_with_wrappers( + sql, params, many=False, executor=self._aexecute + ) + + async def _aexecutemany(self, sql, param_list, *ignored_wrapper_args): + # Raise a warning during app initialization (stored_app_configs is only + # ever set during testing). + if not apps.ready and not apps.stored_app_configs: + warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning) + self.db.validate_no_broken_transaction() + with self.db.wrap_database_errors: + return await self.cursor.executemany(sql, param_list) + + async def aexecutemany(self, sql, param_list): + return await self._aexecute_with_wrappers( + sql, param_list, many=True, executor=self._aexecutemany + ) + + async def afetchone(self, *args, **kwargs): + return await self.cursor.fetchone(*args, **kwargs) + + async def afetchmany(self, *args, **kwargs): + return await self.cursor.fetchmany(*args, **kwargs) + + async def afetchall(self, *args, **kwargs): + return await self.cursor.fetchall(*args, **kwargs) + + async def acopy(self, *args, **kwargs): + return await self.cursor.copy(*args, **kwargs) + + async def astream(self, *args, **kwargs): + return await self.cursor.stream(*args, **kwargs) + + async def ascroll(self, *args, **kwargs): + return await self.cursor.ascroll(*args, **kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, type, value, traceback): + try: + await self.aclose() + except self.db.Database.Error: + pass + + async def aclose(self): + with unblock_sync_ops(): + await self.close() + + class CursorDebugWrapper(CursorWrapper): # XXX callproc isn't instrumented at this time. @@ -163,6 +317,57 @@ def debug_sql( ) +class AsyncCursorDebugWrapper(AsyncCursorWrapper): + # XXX callproc isn't instrumented at this time. + + async def aexecute(self, sql, params=None): + with self.debug_sql(sql, params, use_last_executed_query=True): + return await super().aexecute(sql, params) + + async def aexecutemany(self, sql, param_list): + with self.debug_sql(sql, param_list, many=True): + return await super().aexecutemany(sql, param_list) + + @contextmanager + def debug_sql( + self, sql=None, params=None, use_last_executed_query=False, many=False + ): + start = time.monotonic() + try: + yield + finally: + stop = time.monotonic() + duration = stop - start + if use_last_executed_query: + sql = self.db.ops.last_executed_query(self.cursor, sql, params) + try: + times = len(params) if many else "" + except TypeError: + # params could be an iterator. + times = "?" + self.db.queries_log.append( + { + "sql": "%s times: %s" % (times, sql) if many else sql, + "time": "%.3f" % duration, + "async": True, + } + ) + logger.debug( + "(%.3f) %s; args=%s; alias=%s; async=True", + duration, + sql, + params, + self.db.alias, + extra={ + "duration": duration, + "sql": sql, + "params": params, + "alias": self.db.alias, + "async": True, + }, + ) + + @contextmanager def debug_transaction(connection, sql): start = time.monotonic() @@ -176,18 +381,21 @@ def debug_transaction(connection, sql): { "sql": "%s" % sql, "time": "%.3f" % duration, + "async": connection.supports_async, } ) logger.debug( - "(%.3f) %s; args=%s; alias=%s", + "(%.3f) %s; args=%s; alias=%s; async=%s", duration, sql, None, connection.alias, + connection.supports_async, extra={ "duration": duration, "sql": sql, "alias": connection.alias, + "async": connection.supports_async, }, ) diff --git a/django/db/models/base.py b/django/db/models/base.py index 575365e11c73..f00ccbb3501c 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -5,7 +5,7 @@ from functools import partialmethod from itertools import chain -from asgiref.sync import sync_to_async +from asgiref.sync import async_to_sync, sync_to_async import django from django.apps import apps @@ -25,6 +25,7 @@ connection, connections, router, + should_use_sync_fallback, transaction, ) from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, Value @@ -50,6 +51,7 @@ pre_save, ) from django.db.models.utils import AltersData, make_model_tuple +from django.utils.codegen import from_codegen, generate_unasynced, ASYNC_TRUTH_MARKER from django.utils.encoding import force_str from django.utils.hashable import make_hashable from django.utils.text import capfirst, get_text_list @@ -585,6 +587,28 @@ def from_db(cls, db, field_names, values): new._state.db = db return new + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # the following are pairings of sync and async variants of model methods + # if a subclass overrides one of these without overriding the other, then + # we should make the other one fallback to using the overriding one + # + # for example: if I override save, then asave should call into my overridden + # save, instead of the default asave (which does it's own thing) + method_pairings = [ + ("save", "asave"), + ] + + for sync_variant, async_variant in method_pairings: + sync_defined = sync_variant in cls.__dict__ + async_defined = async_variant in cls.__dict__ + if sync_defined and not async_defined: + # async should fallback to sync + setattr(cls, async_variant, sync_to_async(getattr(cls, sync_variant))) + if not sync_defined and async_defined: + # sync should fallback to async + setattr(cls, sync_variant, async_to_sync(getattr(cls, async_variant))) + def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self) @@ -785,6 +809,7 @@ def serializable_value(self, field_name): return getattr(self, field_name) return getattr(self, field.attname) + @from_codegen def save( self, *, @@ -801,7 +826,6 @@ def save( that the "save" must be an SQL insert or update (or equivalent for non-SQL backends), respectively. Normally, they should not be set. """ - self._prepare_related_fields_for_save(operation_name="save") using = using or router.db_for_write(self.__class__, instance=self) @@ -854,8 +878,7 @@ def save( update_fields=update_fields, ) - save.alters_data = True - + @generate_unasynced() async def asave( self, *, @@ -864,13 +887,75 @@ async def asave( using=None, update_fields=None, ): - return await sync_to_async(self.save)( + """ + Save the current instance. Override this in a subclass if you want to + control the saving process. + + The 'force_insert' and 'force_update' parameters can be used to insist + that the "save" must be an SQL insert or update (or equivalent for + non-SQL backends), respectively. Normally, they should not be set. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.save)( + force_insert=force_insert, + force_update=force_update, + using=using, + update_fields=update_fields, + ) + self._prepare_related_fields_for_save(operation_name="save") + + using = using or router.db_for_write(self.__class__, instance=self) + if force_insert and (force_update or update_fields): + raise ValueError("Cannot force both insert and updating in model saving.") + + deferred_non_generated_fields = { + f.attname + for f in self._meta.concrete_fields + if f.attname not in self.__dict__ and f.generated is False + } + if update_fields is not None: + # If update_fields is empty, skip the save. We do also check for + # no-op saves later on for inheritance cases. This bailout is + # still needed for skipping signal sending. + if not update_fields: + return + + update_fields = frozenset(update_fields) + field_names = self._meta._non_pk_concrete_field_names + not_updatable_fields = update_fields.difference(field_names) + + if not_updatable_fields: + raise ValueError( + "The following fields do not exist in this model, are m2m " + "fields, primary keys, or are non-concrete fields: %s" + % ", ".join(not_updatable_fields) + ) + + # If saving to the same database, and this model is deferred, then + # automatically do an "update_fields" save on the loaded fields. + elif ( + not force_insert + and deferred_non_generated_fields + and using == self._state.db + ): + field_names = set() + pk_fields = self._meta.pk_fields + for field in self._meta.concrete_fields: + if field not in pk_fields and not hasattr(field, "through"): + field_names.add(field.attname) + loaded_fields = field_names.difference(deferred_non_generated_fields) + if loaded_fields: + update_fields = frozenset(loaded_fields) + + await self.asave_base( + using=using, force_insert=force_insert, force_update=force_update, - using=using, update_fields=update_fields, ) + save.alters_data = True asave.alters_data = True @classmethod @@ -893,6 +978,7 @@ def _validate_force_insert(cls, force_insert): ) return force_insert + @from_codegen def save_base( self, raw=False, @@ -939,6 +1025,7 @@ def save_base( parent_inserted = self._save_parents( cls, using, update_fields, force_insert ) + updated = self._save_table( raw, cls, @@ -963,8 +1050,82 @@ def save_base( using=using, ) + @generate_unasynced() + async def asave_base( + self, + raw=False, + force_insert=False, + force_update=False, + using=None, + update_fields=None, + ): + """ + Handle the parts of saving which should be done only once per save, + yet need to be done in raw saves, too. This includes some sanity + checks and signal sending. + + The 'raw' argument is telling save_base not to save any parent + models and not to do any changes to the values before save. This + is used by fixture loading. + """ + using = using or router.db_for_write(self.__class__, instance=self) + assert not (force_insert and (force_update or update_fields)) + assert update_fields is None or update_fields + cls = origin = self.__class__ + # Skip proxies, but keep the origin as the proxy model. + if cls._meta.proxy: + cls = cls._meta.concrete_model + meta = cls._meta + if not meta.auto_created: + pre_save.send( + sender=origin, + instance=self, + raw=raw, + using=using, + update_fields=update_fields, + ) + # A transaction isn't needed if one query is issued. + if meta.parents: + context_manager = transaction.atomic(using=using, savepoint=False) + else: + context_manager = transaction.mark_for_rollback_on_error(using=using) + async with context_manager: + parent_inserted = False + if not raw: + # Validate force insert only when parents are inserted. + force_insert = self._validate_force_insert(force_insert) + parent_inserted = await self._asave_parents( + cls, using, update_fields, force_insert + ) + + updated = await self._asave_table( + raw, + cls, + force_insert or parent_inserted, + force_update, + using, + update_fields, + ) + # Store the database on which the object was saved + self._state.db = using + # Once saved, this is no longer a to-be-added instance. + self._state.adding = False + + # Signal that the save is complete + if not meta.auto_created: + post_save.send( + sender=origin, + instance=self, + created=(not updated), + update_fields=update_fields, + raw=raw, + using=using, + ) + save_base.alters_data = True + asave_base.alters_data = True + @from_codegen def _save_parents( self, cls, using, update_fields, force_insert, updated_parents=None ): @@ -1012,6 +1173,55 @@ def _save_parents( field.delete_cached_value(self) return inserted + @generate_unasynced() + async def _asave_parents( + self, cls, using, update_fields, force_insert, updated_parents=None + ): + """Save all the parents of cls using values from self.""" + meta = cls._meta + inserted = False + if updated_parents is None: + updated_parents = {} + for parent, field in meta.parents.items(): + # Make sure the link fields are synced between parent and self. + if ( + field + and getattr(self, parent._meta.pk.attname) is None + and getattr(self, field.attname) is not None + ): + setattr(self, parent._meta.pk.attname, getattr(self, field.attname)) + if (parent_updated := updated_parents.get(parent)) is None: + parent_inserted = await self._asave_parents( + cls=parent, + using=using, + update_fields=update_fields, + force_insert=force_insert, + updated_parents=updated_parents, + ) + updated = await self._asave_table( + cls=parent, + using=using, + update_fields=update_fields, + force_insert=parent_inserted or issubclass(parent, force_insert), + ) + if not updated: + inserted = True + updated_parents[parent] = updated + elif not parent_updated: + inserted = True + # Set the parent's PK value to self. + if field: + setattr(self, field.attname, self._get_pk_val(parent._meta)) + # Since we didn't have an instance of the parent handy set + # attname directly, bypassing the descriptor. Invalidate + # the related object cache, in case it's been accidentally + # populated. A fresh instance will be re-built from the + # database if necessary. + if field.is_cached(self): + field.delete_cached_value(self) + return inserted + + @from_codegen def _save_table( self, raw=False, @@ -1108,6 +1318,104 @@ def _save_table( setattr(self, field.attname, value) return updated + @generate_unasynced() + async def _asave_table( + self, + raw=False, + cls=None, + force_insert=False, + force_update=False, + using=None, + update_fields=None, + ): + """ + Do the heavy-lifting involved in saving. Update or insert the data + for a single table. + """ + meta = cls._meta + pk_fields = meta.pk_fields + non_pks_non_generated = [ + f + for f in meta.local_concrete_fields + if f not in pk_fields and not f.generated + ] + + if update_fields: + non_pks_non_generated = [ + f + for f in non_pks_non_generated + if f.name in update_fields or f.attname in update_fields + ] + + if not self._is_pk_set(meta): + pk_val = meta.pk.get_pk_value_on_save(self) + setattr(self, meta.pk.attname, pk_val) + pk_set = self._is_pk_set(meta) + if not pk_set and (force_update or update_fields): + raise ValueError("Cannot force an update in save() with no primary key.") + updated = False + # Skip an UPDATE when adding an instance and primary key has a default. + if ( + not raw + and not force_insert + and not force_update + and self._state.adding + and all(f.has_default() or f.has_db_default() for f in meta.pk_fields) + ): + force_insert = True + # If possible, try an UPDATE. If that doesn't update anything, do an INSERT. + if pk_set and not force_insert: + base_qs = cls._base_manager.using(using) + values = [ + ( + f, + None, + (getattr(self, f.attname) if raw else f.pre_save(self, False)), + ) + for f in non_pks_non_generated + ] + forced_update = update_fields or force_update + pk_val = self._get_pk_val(meta) + updated = await self._ado_update( + base_qs, using, pk_val, values, update_fields, forced_update + ) + if force_update and not updated: + raise DatabaseError("Forced update did not affect any rows.") + if update_fields and not updated: + raise DatabaseError("Save with update_fields did not affect any rows.") + if not updated: + if meta.order_with_respect_to: + # If this is a model with an order_with_respect_to + # autopopulate the _order field + field = meta.order_with_respect_to + filter_args = field.get_filter_kwargs_for_object(self) + self._order = ( + cls._base_manager.using(using) + .filter(**filter_args) + .aggregate( + _order__max=Coalesce( + ExpressionWrapper( + Max("_order") + Value(1), output_field=IntegerField() + ), + Value(0), + ), + )["_order__max"] + ) + fields = [ + f + for f in meta.local_concrete_fields + if not f.generated and (pk_set or f is not meta.auto_field) + ] + returning_fields = meta.db_returning_fields + results = await self._ado_insert( + cls._base_manager, using, fields, returning_fields, raw + ) + if results: + for value, field in zip(results[0], returning_fields): + setattr(self, field.attname, value) + return updated + + @from_codegen def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): """ Try to update the model. Return True if the model was updated (if an @@ -1136,6 +1444,38 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat ) return filtered._update(values) > 0 + @generate_unasynced() + async def _ado_update( + self, base_qs, using, pk_val, values, update_fields, forced_update + ): + """ + Try to update the model. Return True if the model was updated (if an + update query was done and a matching row was found in the DB). + """ + filtered = base_qs.filter(pk=pk_val) + if not values: + # We can end up here when saving a model in inheritance chain where + # update_fields doesn't target any field in current model. In that + # case we just say the update succeeded. Another case ending up here + # is a model with just PK - in that case check that the PK still + # exists. + return update_fields is not None or await filtered.aexists() + if self._meta.select_on_save and not forced_update: + return ( + await filtered.aexists() + and + # It may happen that the object is deleted from the DB right after + # this check, causing the subsequent UPDATE to return zero matching + # rows. The same result can occur in some rare cases when the + # database returns zero despite the UPDATE being executed + # successfully (a row is matched and updated). In order to + # distinguish these two cases, the object's existence in the + # database is again checked for if the UPDATE query returns 0. + (await filtered._aupdate(values) > 0 or (await filtered.aexists())) + ) + return await filtered._aupdate(values) > 0 + + @from_codegen def _do_insert(self, manager, using, fields, returning_fields, raw): """ Do an INSERT. If returning_fields is defined then this method should @@ -1149,6 +1489,20 @@ def _do_insert(self, manager, using, fields, returning_fields, raw): raw=raw, ) + @generate_unasynced() + async def _ado_insert(self, manager, using, fields, returning_fields, raw): + """ + Do an INSERT. If returning_fields is defined then this method should + return the newly created data for the model. + """ + return await manager._ainsert( + [self], + fields=fields, + returning_fields=returning_fields, + using=using, + raw=raw, + ) + def _prepare_related_fields_for_save(self, operation_name, fields=None): # Ensure that a model instance without a PK hasn't been assigned to # a ForeignKey, GenericForeignKey or OneToOneField on this model. If diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index fd3d290a9632..f9deb5ddec80 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -5,6 +5,7 @@ from django.db import IntegrityError, connections, models, transaction from django.db.models import query_utils, signals, sql +from django.utils.codegen import from_codegen, generate_unasynced class ProtectedError(IntegrityError): @@ -113,6 +114,21 @@ def __init__(self, using, origin=None): # parent. self.dependencies = defaultdict(set) # {model: {models}} + @from_codegen + def bool(self, elts): + if hasattr(elts, "_afetch_then_len"): + return bool(elts._fetch_then_len()) + else: + return bool(elts) + + @generate_unasynced() + async def abool(self, elts): + if hasattr(elts, "_afetch_then_len"): + return bool(await elts._afetch_then_len()) + else: + return bool(elts) + + @from_codegen def add(self, objs, source=None, nullable=False, reverse_dependency=False): """ Add 'objs' to the collection of objects to be deleted. If the call is @@ -121,7 +137,8 @@ def add(self, objs, source=None, nullable=False, reverse_dependency=False): Return a list of all objects that were not already collected. """ - if not objs: + # XXX incorrect hack + if not self.bool(objs): return [] new_objs = [] model = objs[0].__class__ @@ -137,6 +154,32 @@ def add(self, objs, source=None, nullable=False, reverse_dependency=False): self.add_dependency(source, model, reverse_dependency=reverse_dependency) return new_objs + @generate_unasynced() + async def aadd(self, objs, source=None, nullable=False, reverse_dependency=False): + """ + Add 'objs' to the collection of objects to be deleted. If the call is + the result of a cascade, 'source' should be the model that caused it, + and 'nullable' should be set to True if the relation can be null. + + Return a list of all objects that were not already collected. + """ + # XXX incorrect hack + if not (await self.abool(objs)): + return [] + new_objs = [] + model = objs[0].__class__ + instances = self.data[model] + async for obj in objs: + if obj not in instances: + new_objs.append(obj) + instances.update(new_objs) + # Nullable relationships can be ignored -- they are nulled out before + # deleting, and therefore do not affect the order in which objects have + # to be deleted. + if source is not None and not nullable: + self.add_dependency(source, model, reverse_dependency=reverse_dependency) + return new_objs + def add_dependency(self, model, dependency, reverse_dependency=False): if reverse_dependency: model, dependency = dependency, model @@ -242,6 +285,7 @@ def get_del_batches(self, objs, fields): else: return [objs] + @from_codegen def collect( self, objs, @@ -396,6 +440,161 @@ def collect( set(chain.from_iterable(restricted_objects.values())), ) + @generate_unasynced() + async def acollect( + self, + objs, + source=None, + nullable=False, + collect_related=True, + source_attr=None, + reverse_dependency=False, + keep_parents=False, + fail_on_restricted=True, + ): + """ + Add 'objs' to the collection of objects to be deleted as well as all + parent instances. 'objs' must be a homogeneous iterable collection of + model instances (e.g. a QuerySet). If 'collect_related' is True, + related objects will be handled by their respective on_delete handler. + + If the call is the result of a cascade, 'source' should be the model + that caused it and 'nullable' should be set to True, if the relation + can be null. + + If 'reverse_dependency' is True, 'source' will be deleted before the + current model, rather than after. (Needed for cascading to parent + models, the one case in which the cascade follows the forwards + direction of an FK rather than the reverse direction.) + + If 'keep_parents' is True, data of parent model's will be not deleted. + + If 'fail_on_restricted' is False, error won't be raised even if it's + prohibited to delete such objects due to RESTRICT, that defers + restricted object checking in recursive calls where the top-level call + may need to collect more objects to determine whether restricted ones + can be deleted. + """ + if self.can_fast_delete(objs): + self.fast_deletes.append(objs) + return + new_objs = await self.aadd( + objs, source, nullable, reverse_dependency=reverse_dependency + ) + if not new_objs: + return + + model = new_objs[0].__class__ + + if not keep_parents: + # Recursively collect concrete model's parent models, but not their + # related objects. These will be found by meta.get_fields() + concrete_model = model._meta.concrete_model + for ptr in concrete_model._meta.parents.values(): + if ptr: + parent_objs = [getattr(obj, ptr.name) for obj in new_objs] + await self.acollect( + parent_objs, + source=model, + source_attr=ptr.remote_field.related_name, + collect_related=False, + reverse_dependency=True, + fail_on_restricted=False, + ) + if not collect_related: + return + + model_fast_deletes = defaultdict(list) + protected_objects = defaultdict(list) + for related in get_candidate_relations_to_delete(model._meta): + # Preserve parent reverse relationships if keep_parents=True. + if keep_parents and related.model in model._meta.all_parents: + continue + field = related.field + on_delete = field.remote_field.on_delete + if on_delete == DO_NOTHING: + continue + related_model = related.related_model + if self.can_fast_delete(related_model, from_field=field): + model_fast_deletes[related_model].append(field) + continue + batches = self.get_del_batches(new_objs, [field]) + for batch in batches: + sub_objs = self.related_objects(related_model, [field], batch) + # Non-referenced fields can be deferred if no signal receivers + # are connected for the related model as they'll never be + # exposed to the user. Skip field deferring when some + # relationships are select_related as interactions between both + # features are hard to get right. This should only happen in + # the rare cases where .related_objects is overridden anyway. + if not ( + sub_objs.query.select_related + or self._has_signal_listeners(related_model) + ): + referenced_fields = set( + chain.from_iterable( + (rf.attname for rf in rel.field.foreign_related_fields) + for rel in get_candidate_relations_to_delete( + related_model._meta + ) + ) + ) + sub_objs = sub_objs.only(*tuple(referenced_fields)) + if getattr(on_delete, "lazy_sub_objs", False) or sub_objs: + try: + on_delete(self, field, sub_objs, self.using) + except ProtectedError as error: + key = "'%s.%s'" % (field.model.__name__, field.name) + protected_objects[key] += error.protected_objects + if protected_objects: + raise ProtectedError( + "Cannot delete some instances of model %r because they are " + "referenced through protected foreign keys: %s." + % ( + model.__name__, + ", ".join(protected_objects), + ), + set(chain.from_iterable(protected_objects.values())), + ) + for related_model, related_fields in model_fast_deletes.items(): + batches = self.get_del_batches(new_objs, related_fields) + for batch in batches: + sub_objs = self.related_objects(related_model, related_fields, batch) + self.fast_deletes.append(sub_objs) + for field in model._meta.private_fields: + if hasattr(field, "bulk_related_objects"): + # It's something like generic foreign key. + sub_objs = field.bulk_related_objects(new_objs, self.using) + self.collect( + sub_objs, source=model, nullable=True, fail_on_restricted=False + ) + + if fail_on_restricted: + # Raise an error if collected restricted objects (RESTRICT) aren't + # candidates for deletion also collected via CASCADE. + for related_model, instances in self.data.items(): + self.clear_restricted_objects_from_set(related_model, instances) + for qs in self.fast_deletes: + self.clear_restricted_objects_from_queryset(qs.model, qs) + if self.restricted_objects.values(): + restricted_objects = defaultdict(list) + for related_model, fields in self.restricted_objects.items(): + for field, objs in fields.items(): + if objs: + key = "'%s.%s'" % (related_model.__name__, field.name) + restricted_objects[key] += objs + if restricted_objects: + raise RestrictedError( + "Cannot delete some instances of model %r because " + "they are referenced through restricted foreign keys: " + "%s." + % ( + model.__name__, + ", ".join(restricted_objects), + ), + set(chain.from_iterable(restricted_objects.values())), + ) + def related_objects(self, related_model, related_fields, objs): """ Get a QuerySet of the related model to objs via related fields. @@ -429,6 +628,7 @@ def sort(self): return self.data = {model: self.data[model] for model in sorted_models} + @from_codegen def delete(self): # sort instance collections for model, instances in self.data.items(): @@ -516,3 +716,92 @@ def delete(self): for instance in instances: setattr(instance, model._meta.pk.attname, None) return sum(deleted_counter.values()), dict(deleted_counter) + + @generate_unasynced() + async def adelete(self): + # sort instance collections + for model, instances in self.data.items(): + self.data[model] = sorted(instances, key=attrgetter("pk")) + + # if possible, bring the models in an order suitable for databases that + # don't support transactions or cannot defer constraint checks until the + # end of a transaction. + self.sort() + # number of objects deleted for each model label + deleted_counter = Counter() + + # Optimize for the case with a single obj and no dependencies + if len(self.data) == 1 and len(instances) == 1: + instance = list(instances)[0] + if self.can_fast_delete(instance): + with transaction.mark_for_rollback_on_error(self.using): + count = await sql.DeleteQuery(model).adelete_batch( + [instance.pk], self.using + ) + setattr(instance, model._meta.pk.attname, None) + return count, {model._meta.label: count} + + async with transaction.atomic(using=self.using, savepoint=False): + # send pre_delete signals + for model, obj in self.instances_with_model(): + if not model._meta.auto_created: + signals.pre_delete.send( + sender=model, + instance=obj, + using=self.using, + origin=self.origin, + ) + + # fast deletes + for qs in self.fast_deletes: + count = await qs._araw_delete(using=self.using) + if count: + deleted_counter[qs.model._meta.label] += count + + # update fields + for (field, value), instances_list in self.field_updates.items(): + updates = [] + objs = [] + for instances in instances_list: + if ( + isinstance(instances, models.QuerySet) + and instances._result_cache is None + ): + updates.append(instances) + else: + objs.extend(instances) + if updates: + combined_updates = reduce(or_, updates) + await combined_updates.aupdate(**{field.name: value}) + if objs: + model = objs[0].__class__ + query = sql.UpdateQuery(model) + await query.aupdate_batch( + list({obj.pk for obj in objs}), {field.name: value}, self.using + ) + + # reverse instance collections + for instances in self.data.values(): + instances.reverse() + + # delete instances + for model, instances in self.data.items(): + query = sql.DeleteQuery(model) + pk_list = [obj.pk for obj in instances] + count = await query.adelete_batch(pk_list, self.using) + if count: + deleted_counter[model._meta.label] += count + + if not model._meta.auto_created: + for obj in instances: + signals.post_delete.send( + sender=model, + instance=obj, + using=self.using, + origin=self.origin, + ) + + for model, instances in self.data.items(): + for instance in instances: + setattr(instance, model._meta.pk.attname, None) + return sum(deleted_counter.values()), dict(deleted_counter) diff --git a/django/db/models/query.py b/django/db/models/query.py index eb17624bf108..d97eebee71d1 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -16,8 +16,10 @@ DJANGO_VERSION_PICKLE_KEY, IntegrityError, NotSupportedError, + async_connections, connections, router, + should_use_sync_fallback, transaction, ) from django.db.models import AutoField, DateField, DateTimeField, Field, sql @@ -26,13 +28,14 @@ from django.db.models.expressions import Case, F, Value, When from django.db.models.functions import Cast, Trunc from django.db.models.query_utils import FilteredRelation, Q -from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, ROW_COUNT +from django.db.models.sql.constants import ROW_COUNT, CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.utils import ( AltersData, create_namedtuple_class, resolve_callables, ) from django.utils import timezone +from django.utils.codegen import ASYNC_TRUTH_MARKER, from_codegen, generate_unasynced from django.utils.functional import cached_property, partition # The maximum number of results to fetch in a get() query. @@ -50,11 +53,11 @@ def __init__( self.chunked_fetch = chunked_fetch self.chunk_size = chunk_size - async def _async_generator(self): + async def _sync_to_async_generator(self): # Generators don't actually start running until the first time you call # next() on them, so make the generator object in the async thread and # then repeatedly dispatch to it in a sync thread. - sync_generator = self.__iter__() + sync_generator = await sync_to_async(self.__iter__)() def next_slice(gen): return list(islice(gen, self.chunk_size)) @@ -66,6 +69,8 @@ def next_slice(gen): if len(chunk) < self.chunk_size: break + _async_generator = _sync_to_async_generator + # __aiter__() is a *synchronous* method that has to then return an # *asynchronous* iterator/generator. Thus, nest an async generator inside # it. @@ -75,13 +80,30 @@ def next_slice(gen): # be added to each Iterable subclass, but that needs some work in the # Compiler first. def __aiter__(self): - return self._async_generator() + # not clear to me if we need this fallback, to investigate + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return self._sync_to_async_generator() + else: + return self._agenerator() + + def __iter__(self): + return self._generator() + + def _generator(self): + raise NotImplementedError() + + def _agenerator(self): + raise NotImeplementedError() class ModelIterable(BaseIterable): """Iterable that yields a model instance for each row.""" def __iter__(self): + return self._generator() + + @from_codegen + def _generator(self): queryset = self.queryset db = queryset.db compiler = queryset.query.get_compiler(using=db) @@ -144,6 +166,79 @@ def __iter__(self): yield obj + @generate_unasynced() + async def _agenerator(self): + queryset = self.queryset + db = queryset.db + if ASYNC_TRUTH_MARKER: + compiler = queryset.query.aget_compiler(using=db) + else: + compiler = queryset.query.get_compiler(using=db) + # Execute the query. This will also fill compiler.select, klass_info, + # and annotations. + results = await compiler.aexecute_sql( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ) + select, klass_info, annotation_col_map = ( + compiler.select, + compiler.klass_info, + compiler.annotation_col_map, + ) + model_cls = klass_info["model"] + select_fields = klass_info["select_fields"] + model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 + init_list = [ + f[0].target.attname for f in select[model_fields_start:model_fields_end] + ] + related_populators = get_related_populators(klass_info, select, db) + known_related_objects = [ + ( + field, + related_objs, + operator.attrgetter( + *[ + ( + field.attname + if from_field == "self" + else queryset.model._meta.get_field(from_field).attname + ) + for from_field in field.from_fields + ] + ), + ) + for field, related_objs in queryset._known_related_objects.items() + ] + for row in await compiler.aresults_iter(results): + obj = model_cls.from_db( + db, init_list, row[model_fields_start:model_fields_end] + ) + for rel_populator in related_populators: + rel_populator.populate(row, obj) + if annotation_col_map: + for attr_name, col_pos in annotation_col_map.items(): + setattr(obj, attr_name, row[col_pos]) + + # Add the known related objects to the model. + for field, rel_objs, rel_getter in known_related_objects: + # Avoid overwriting objects loaded by, e.g., select_related(). + if field.is_cached(obj): + continue + rel_obj_id = rel_getter(obj) + try: + rel_obj = rel_objs[rel_obj_id] + except KeyError: + pass # May happen in qs1 | qs2 scenarios. + else: + setattr(obj, field.name, rel_obj) + + yield obj + + def __aiter__(self): + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return self._sync_to_async_generator() + else: + return self._agenerator() + class RawModelIterable(BaseIterable): """ @@ -361,10 +456,19 @@ def __repr__(self): data[-1] = "...(remaining elements truncated)..." return "<%s %r>" % (self.__class__.__name__, data) - def __len__(self): + @from_codegen + def _fetch_then_len(self): self._fetch_all() return len(self._result_cache) + @generate_unasynced() + async def _afetch_then_len(self): + await self._afetch_all() + return len(self._result_cache) + + def __len__(self): + return self._fetch_then_len() + def __iter__(self): """ The queryset iterator protocol uses three nested iterators in the @@ -387,7 +491,7 @@ def __aiter__(self): # Remember, __aiter__ itself is synchronous, it's the thing it returns # that is async! async def generator(): - await sync_to_async(self._fetch_all)() + await self._afetch_all() for item in self._result_cache: yield item @@ -561,6 +665,7 @@ async def aiterator(self, chunk_size=2000): async for item in iterable: yield item + @from_codegen def aggregate(self, *args, **kwargs): """ Return a dictionary containing the calculations (aggregation) @@ -586,9 +691,36 @@ def aggregate(self, *args, **kwargs): return self.query.chain().get_aggregation(self.db, kwargs) + @generate_unasynced() async def aaggregate(self, *args, **kwargs): - return await sync_to_async(self.aggregate)(*args, **kwargs) + """ + Return a dictionary containing the calculations (aggregation) + over the current queryset. + If args is present the expression is passed as a kwarg using + the Aggregate object's default alias. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.aggregate)(*args, **kwargs) + if self.query.distinct_fields: + raise NotImplementedError("aggregate() + distinct(fields) not implemented.") + self._validate_values_are_expressions( + (*args, *kwargs.values()), method_name="aggregate" + ) + for arg in args: + # The default_alias property raises TypeError if default_alias + # can't be set automatically or AttributeError if it isn't an + # attribute. + try: + arg.default_alias + except (AttributeError, TypeError): + raise TypeError("Complex aggregates require an alias") + kwargs[arg.default_alias] = arg + + return await self.query.chain().aget_aggregation(self.db, kwargs) + + @from_codegen def count(self): """ Perform a SELECT COUNT() and return the number of records as an @@ -602,14 +734,30 @@ def count(self): return self.query.get_count(using=self.db) + @generate_unasynced() async def acount(self): - return await sync_to_async(self.count)() + """ + Perform a SELECT COUNT() and return the number of records as an + integer. + + If the QuerySet is already fully cached, return the length of the + cached results set to avoid multiple SELECT COUNT(*) calls. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.count)() + if self._result_cache is not None: + return len(self._result_cache) + + return await self.query.aget_count(using=self.db) + @from_codegen def get(self, *args, **kwargs): """ Perform the query and return a single object matching the given keyword arguments. """ + if self.query.combinator and (args or kwargs): raise NotSupportedError( "Calling QuerySet.get(...) with filters after %s() is not " @@ -625,7 +773,7 @@ def get(self, *args, **kwargs): ): limit = MAX_GET_RESULTS clone.query.set_limits(high=limit) - num = len(clone) + num = clone._fetch_then_len() if num == 1: return clone._result_cache[0] if not num: @@ -640,9 +788,47 @@ def get(self, *args, **kwargs): ) ) + @generate_unasynced() async def aget(self, *args, **kwargs): - return await sync_to_async(self.get)(*args, **kwargs) + """ + Perform the query and return a single object matching the given + keyword arguments. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.get)(*args, **kwargs) + if self.query.combinator and (args or kwargs): + raise NotSupportedError( + "Calling QuerySet.get(...) with filters after %s() is not " + "supported." % self.query.combinator + ) + clone = self._chain() if self.query.combinator else self.filter(*args, **kwargs) + if self.query.can_filter() and not self.query.distinct_fields: + clone = clone.order_by() + limit = None + if ( + not clone.query.select_for_update + or connections[clone.db].features.supports_select_for_update_with_limit + ): + limit = MAX_GET_RESULTS + clone.query.set_limits(high=limit) + num = await clone._afetch_then_len() + if num == 1: + return clone._result_cache[0] + if not num: + raise self.model.DoesNotExist( + "%s matching query does not exist." % self.model._meta.object_name + ) + raise self.model.MultipleObjectsReturned( + "get() returned more than one %s -- it returned %s!" + % ( + self.model._meta.object_name, + num if not limit or num < limit else "more than %s" % (limit - 1), + ) + ) + + @from_codegen def create(self, **kwargs): """ Create a new object with the given kwargs, saving it to the database @@ -662,11 +848,30 @@ def create(self, **kwargs): obj.save(force_insert=True, using=self.db) return obj - create.alters_data = True - + @generate_unasynced() async def acreate(self, **kwargs): - return await sync_to_async(self.create)(**kwargs) + """ + Create a new object with the given kwargs, saving it to the database + and returning the created object. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.create)(**kwargs) + reverse_one_to_one_fields = frozenset(kwargs).intersection( + self.model._meta._reverse_one_to_one_field_names + ) + if reverse_one_to_one_fields: + raise ValueError( + "The following fields do not exist in this model: %s" + % ", ".join(reverse_one_to_one_fields) + ) + + obj = self.model(**kwargs) + self._for_write = True + await obj.asave(force_insert=True, using=self.db) + return obj + create.alters_data = True acreate.alters_data = True def _prepare_for_bulk_create(self, objs): @@ -741,6 +946,7 @@ def _check_bulk_create_options( return OnConflict.UPDATE return None + @from_codegen def bulk_create( self, objs, @@ -841,8 +1047,7 @@ def bulk_create( return objs - bulk_create.alters_data = True - + @generate_unasynced() async def abulk_create( self, objs, @@ -852,17 +1057,110 @@ async def abulk_create( update_fields=None, unique_fields=None, ): - return await sync_to_async(self.bulk_create)( - objs=objs, - batch_size=batch_size, - ignore_conflicts=ignore_conflicts, - update_conflicts=update_conflicts, - update_fields=update_fields, - unique_fields=unique_fields, + """ + Insert each of the instances into the database. Do *not* call + save() on each of the instances, do not send any pre/post_save + signals, and do not set the primary key attribute if it is an + autoincrement field (except if features.can_return_rows_from_bulk_insert=True). + Multi-table models are not supported. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.bulk_create)( + objs=objs, + batch_size=batch_size, + ignore_conflicts=ignore_conflicts, + update_conflicts=update_conflicts, + update_fields=update_fields, + unique_fields=unique_fields, + ) + # When you bulk insert you don't get the primary keys back (if it's an + # autoincrement, except if can_return_rows_from_bulk_insert=True), so + # you can't insert into the child tables which references this. There + # are two workarounds: + # 1) This could be implemented if you didn't have an autoincrement pk + # 2) You could do it by doing O(n) normal inserts into the parent + # tables to get the primary keys back and then doing a single bulk + # insert into the childmost table. + # We currently set the primary keys on the objects when using + # PostgreSQL via the RETURNING ID clause. It should be possible for + # Oracle as well, but the semantics for extracting the primary keys is + # trickier so it's not done yet. + if batch_size is not None and batch_size <= 0: + raise ValueError("Batch size must be a positive integer.") + # Check that the parents share the same concrete model with the our + # model to detect the inheritance pattern ConcreteGrandParent -> + # MultiTableParent -> ProxyChild. Simply checking self.model._meta.proxy + # would not identify that case as involving multiple tables. + for parent in self.model._meta.all_parents: + if parent._meta.concrete_model is not self.model._meta.concrete_model: + raise ValueError("Can't bulk create a multi-table inherited model") + if not objs: + return objs + opts = self.model._meta + if unique_fields: + # Primary key is allowed in unique_fields. + unique_fields = [ + self.model._meta.get_field(opts.pk.name if name == "pk" else name) + for name in unique_fields + ] + if update_fields: + update_fields = [self.model._meta.get_field(name) for name in update_fields] + on_conflict = self._check_bulk_create_options( + ignore_conflicts, + update_conflicts, + update_fields, + unique_fields, ) + self._for_write = True + fields = [f for f in opts.concrete_fields if not f.generated] + objs = list(objs) + self._prepare_for_bulk_create(objs) + async with transaction.atomic(using=self.db, savepoint=False): + objs_without_pk, objs_with_pk = partition(lambda o: o._is_pk_set(), objs) + if objs_with_pk: + returned_columns = await self._abatched_insert( + objs_with_pk, + fields, + batch_size, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + ) + for obj_with_pk, results in zip(objs_with_pk, returned_columns): + for result, field in zip(results, opts.db_returning_fields): + if field != opts.pk: + setattr(obj_with_pk, field.attname, result) + for obj_with_pk in objs_with_pk: + obj_with_pk._state.adding = False + obj_with_pk._state.db = self.db + if objs_without_pk: + fields = [f for f in fields if not isinstance(f, AutoField)] + returned_columns = await self._abatched_insert( + objs_without_pk, + fields, + batch_size, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + ) + connection = connections[self.db] + if ( + connection.features.can_return_rows_from_bulk_insert + and on_conflict is None + ): + assert len(returned_columns) == len(objs_without_pk) + for obj_without_pk, results in zip(objs_without_pk, returned_columns): + for result, field in zip(results, opts.db_returning_fields): + setattr(obj_without_pk, field.attname, result) + obj_without_pk._state.adding = False + obj_without_pk._state.db = self.db + + return objs abulk_create.alters_data = True + @from_codegen def bulk_update(self, objs, fields, batch_size=None): """ Update the given fields in each of the given objects in the database. @@ -918,17 +1216,75 @@ def bulk_update(self, objs, fields, batch_size=None): rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs) return rows_updated - bulk_update.alters_data = True - + @generate_unasynced() async def abulk_update(self, objs, fields, batch_size=None): - return await sync_to_async(self.bulk_update)( - objs=objs, - fields=fields, - batch_size=batch_size, - ) + """ + Update the given fields in each of the given objects in the database. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.bulk_update)( + objs=objs, + fields=fields, + batch_size=batch_size, + ) + if batch_size is not None and batch_size <= 0: + raise ValueError("Batch size must be a positive integer.") + if not fields: + raise ValueError("Field names must be given to bulk_update().") + objs = tuple(objs) + if not all(obj._is_pk_set() for obj in objs): + raise ValueError("All bulk_update() objects must have a primary key set.") + fields = [self.model._meta.get_field(name) for name in fields] + if any(not f.concrete or f.many_to_many for f in fields): + raise ValueError("bulk_update() can only be used with concrete fields.") + all_pk_fields = set(self.model._meta.pk_fields) + for parent in self.model._meta.all_parents: + all_pk_fields.update(parent._meta.pk_fields) + if any(f in all_pk_fields for f in fields): + raise ValueError("bulk_update() cannot be used with primary key fields.") + if not objs: + return 0 + for obj in objs: + obj._prepare_related_fields_for_save( + operation_name="bulk_update", fields=fields + ) + # PK is used twice in the resulting update query, once in the filter + # and once in the WHEN. Each field will also have one CAST. + self._for_write = True + connection = connections[self.db] + max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) + batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size + requires_casting = connection.features.requires_casted_case_in_updates + batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) + updates = [] + for batch_objs in batches: + update_kwargs = {} + for field in fields: + when_statements = [] + for obj in batch_objs: + attr = getattr(obj, field.attname) + if not hasattr(attr, "resolve_expression"): + attr = Value(attr, output_field=field) + when_statements.append(When(pk=obj.pk, then=attr)) + case_statement = Case(*when_statements, output_field=field) + if requires_casting: + case_statement = Cast(case_statement, output_field=field) + update_kwargs[field.attname] = case_statement + updates.append(([obj.pk for obj in batch_objs], update_kwargs)) + rows_updated = 0 + queryset = self.using(self.db) + async with transaction.atomic(using=self.db, savepoint=False): + for pks, update_kwargs in updates: + rows_updated += await queryset.filter(pk__in=pks).aupdate( + **update_kwargs + ) + return rows_updated + bulk_update.alters_data = True abulk_update.alters_data = True + @from_codegen def get_or_create(self, defaults=None, **kwargs): """ Look up an object with the given kwargs, creating one if necessary. @@ -954,16 +1310,36 @@ def get_or_create(self, defaults=None, **kwargs): pass raise - get_or_create.alters_data = True - + @generate_unasynced() async def aget_or_create(self, defaults=None, **kwargs): - return await sync_to_async(self.get_or_create)( - defaults=defaults, - **kwargs, - ) + """ + Look up an object with the given kwargs, creating one if necessary. + Return a tuple of (object, created), where created is a boolean + specifying whether an object was created. + """ + # The get() needs to be targeted at the write database in order + # to avoid potential transaction consistency problems. + self._for_write = True + try: + return (await self.aget(**kwargs)), False + except self.model.DoesNotExist: + params = self._extract_model_params(defaults, **kwargs) + # Try to create an object using passed params. + try: + async with transaction.atomic(using=self.db): + params = dict(resolve_callables(params)) + return (await self.acreate(**params)), True + except IntegrityError: + try: + return (await self.aget(**kwargs)), False + except self.model.DoesNotExist: + pass + raise + get_or_create.alters_data = True aget_or_create.alters_data = True + @from_codegen def update_or_create(self, defaults=None, create_defaults=None, **kwargs): """ Look up an object with the given kwargs, updating one with defaults @@ -1010,15 +1386,61 @@ def update_or_create(self, defaults=None, create_defaults=None, **kwargs): obj.save(using=self.db) return obj, False - update_or_create.alters_data = True - + @generate_unasynced() async def aupdate_or_create(self, defaults=None, create_defaults=None, **kwargs): - return await sync_to_async(self.update_or_create)( - defaults=defaults, - create_defaults=create_defaults, - **kwargs, - ) + """ + Look up an object with the given kwargs, updating one with defaults + if it exists, otherwise create a new one. Optionally, an object can + be created with different values than defaults by using + create_defaults. + Return a tuple (object, created), where created is a boolean + specifying whether an object was created. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.update_or_create)( + defaults=defaults, + create_defaults=create_defaults, + **kwargs, + ) + update_defaults = defaults or {} + if create_defaults is None: + create_defaults = update_defaults + self._for_write = True + async with transaction.atomic(using=self.db): + # Lock the row so that a concurrent update is blocked until + # update_or_create() has performed its save. + obj, created = await self.select_for_update().aget_or_create( + create_defaults, **kwargs + ) + if created: + return obj, created + for k, v in resolve_callables(update_defaults): + setattr(obj, k, v) + + update_fields = set(update_defaults) + concrete_field_names = self.model._meta._non_pk_concrete_field_names + # update_fields does not support non-concrete fields. + if concrete_field_names.issuperset(update_fields): + # Add fields which are set on pre_save(), e.g. auto_now fields. + # This is to maintain backward compatibility as these fields + # are not updated unless explicitly specified in the + # update_fields list. + pk_fields = self.model._meta.pk_fields + for field in self.model._meta.local_concrete_fields: + if not ( + field in pk_fields or field.__class__.pre_save is Field.pre_save + ): + update_fields.add(field.name) + if field.name != field.attname: + update_fields.add(field.attname) + await obj.asave(using=self.db, update_fields=update_fields) + else: + await obj.asave(using=self.db) + return obj, False + + update_or_create.alters_data = True aupdate_or_create.alters_data = True def _extract_model_params(self, defaults, **kwargs): @@ -1048,6 +1470,7 @@ def _extract_model_params(self, defaults, **kwargs): ) return params + @from_codegen def _earliest(self, *fields): """ Return the earliest object according to fields (if given) or by the @@ -1070,14 +1493,45 @@ def _earliest(self, *fields): obj.query.add_ordering(*order_by) return obj.get() + @generate_unasynced() + async def _aearliest(self, *fields): + """ + Return the earliest object according to fields (if given) or by the + model's Meta.get_latest_by. + """ + if fields: + order_by = fields + else: + order_by = getattr(self.model._meta, "get_latest_by") + if order_by and not isinstance(order_by, (tuple, list)): + order_by = (order_by,) + if order_by is None: + raise ValueError( + "earliest() and latest() require either fields as positional " + "arguments or 'get_latest_by' in the model's Meta." + ) + obj = self._chain() + obj.query.set_limits(high=1) + obj.query.clear_ordering(force=True) + obj.query.add_ordering(*order_by) + return await obj.aget() + + @from_codegen def earliest(self, *fields): if self.query.is_sliced: raise TypeError("Cannot change a query once a slice has been taken.") return self._earliest(*fields) + @generate_unasynced() async def aearliest(self, *fields): - return await sync_to_async(self.earliest)(*fields) + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.earliest)(*fields) + if self.query.is_sliced: + raise TypeError("Cannot change a query once a slice has been taken.") + return await self._aearliest(*fields) + @from_codegen def latest(self, *fields): """ Return the latest object according to fields (if given) or by the @@ -1087,9 +1541,20 @@ def latest(self, *fields): raise TypeError("Cannot change a query once a slice has been taken.") return self.reverse()._earliest(*fields) + @generate_unasynced() async def alatest(self, *fields): - return await sync_to_async(self.latest)(*fields) + """ + Return the latest object according to fields (if given) or by the + model's Meta.get_latest_by. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.latest)(*fields) + if self.query.is_sliced: + raise TypeError("Cannot change a query once a slice has been taken.") + return await self.reverse()._aearliest(*fields) + @from_codegen def first(self): """Return the first object of a query or None if no match is found.""" if self.ordered: @@ -1100,9 +1565,21 @@ def first(self): for obj in queryset[:1]: return obj + @generate_unasynced() async def afirst(self): - return await sync_to_async(self.first)() + """Return the first object of a query or None if no match is found.""" + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.first)() + if self.ordered: + queryset = self + else: + self._check_ordering_first_last_queryset_aggregation(method="first") + queryset = self.order_by("pk") + async for obj in queryset[:1]: + return obj + @from_codegen def last(self): """Return the last object of a query or None if no match is found.""" if self.ordered: @@ -1113,9 +1590,21 @@ def last(self): for obj in queryset[:1]: return obj + @generate_unasynced() async def alast(self): - return await sync_to_async(self.last)() + """Return the last object of a query or None if no match is found.""" + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.last)() + if self.ordered: + queryset = self.reverse() + else: + self._check_ordering_first_last_queryset_aggregation(method="last") + queryset = self.order_by("-pk") + async for obj in queryset[:1]: + return obj + @from_codegen def in_bulk(self, id_list=None, *, field_name="pk"): """ Return a dictionary mapping each of the given IDs to the object with @@ -1160,12 +1649,58 @@ def in_bulk(self, id_list=None, *, field_name="pk"): qs = self._chain() return {getattr(obj, field_name): obj for obj in qs} + @generate_unasynced() async def ain_bulk(self, id_list=None, *, field_name="pk"): - return await sync_to_async(self.in_bulk)( - id_list=id_list, - field_name=field_name, - ) + """ + Return a dictionary mapping each of the given IDs to the object with + that ID. If `id_list` isn't provided, evaluate the entire QuerySet. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.in_bulk)( + id_list=id_list, + field_name=field_name, + ) + if self.query.is_sliced: + raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().") + if not issubclass(self._iterable_class, ModelIterable): + raise TypeError("in_bulk() cannot be used with values() or values_list().") + opts = self.model._meta + unique_fields = [ + constraint.fields[0] + for constraint in opts.total_unique_constraints + if len(constraint.fields) == 1 + ] + if ( + field_name != "pk" + and not opts.get_field(field_name).unique + and field_name not in unique_fields + and self.query.distinct_fields != (field_name,) + ): + raise ValueError( + "in_bulk()'s field_name must be a unique field but %r isn't." + % field_name + ) + if id_list is not None: + if not id_list: + return {} + filter_key = "{}__in".format(field_name) + batch_size = connections[self.db].features.max_query_params + id_list = tuple(id_list) + # If the database has a limit on the number of query parameters + # (e.g. SQLite), retrieve objects in batches if necessary. + if batch_size and batch_size < len(id_list): + qs = () + for offset in range(0, len(id_list), batch_size): + batch = id_list[offset : offset + batch_size] + qs += tuple(self.filter(**{filter_key: batch})) + else: + qs = self.filter(**{filter_key: id_list}) + else: + qs = self._chain() + return {getattr(obj, field_name): obj async for obj in qs} + @from_codegen def delete(self): """Delete the records in the current QuerySet.""" self._not_support_combined_queries("delete") @@ -1196,15 +1731,47 @@ def delete(self): self._result_cache = None return num_deleted, num_deleted_per_model + @generate_unasynced() + async def adelete(self): + """Delete the records in the current QuerySet.""" + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.delete)() + self._not_support_combined_queries("delete") + if self.query.is_sliced: + raise TypeError("Cannot use 'limit' or 'offset' with delete().") + if self.query.distinct_fields: + raise TypeError("Cannot call delete() after .distinct(*fields).") + if self._fields is not None: + raise TypeError("Cannot call delete() after .values() or .values_list()") + + del_query = self._chain() + + # The delete is actually 2 queries - one to find related objects, + # and one to delete. Make sure that the discovery of related + # objects is performed on the same database as the deletion. + del_query._for_write = True + + # Disable non-supported fields. + del_query.query.select_for_update = False + del_query.query.select_related = False + del_query.query.clear_ordering(force=True) + + collector = Collector(using=del_query.db, origin=self) + await collector.acollect(del_query) + num_deleted, num_deleted_per_model = await collector.adelete() + + # Clear the result cache, in case this QuerySet gets reused. + self._result_cache = None + return num_deleted, num_deleted_per_model + delete.alters_data = True delete.queryset_only = True - async def adelete(self): - return await sync_to_async(self.delete)() - adelete.alters_data = True adelete.queryset_only = True + @from_codegen def _raw_delete(self, using): """ Delete objects found from the given queryset in single direct SQL @@ -1214,8 +1781,20 @@ def _raw_delete(self, using): query.__class__ = sql.DeleteQuery return query.get_compiler(using).execute_sql(ROW_COUNT) + @generate_unasynced() + async def _araw_delete(self, using): + """ + Delete objects found from the given queryset in single direct SQL + query. No signals are sent and there is no protection for cascades. + """ + query = self.query.clone() + query.__class__ = sql.DeleteQuery + return await query.aget_compiler(using).aexecute_sql(ROW_COUNT) + _raw_delete.alters_data = True + _araw_delete.alters_data = True + @from_codegen def update(self, **kwargs): """ Update all elements in the current QuerySet, setting all the given @@ -1250,18 +1829,59 @@ def update(self, **kwargs): # Clear any annotations so that they won't be present in subqueries. query.annotations = {} - with transaction.mark_for_rollback_on_error(using=self.db): + with transaction.amark_for_rollback_on_error(using=self.db): rows = query.get_compiler(self.db).execute_sql(ROW_COUNT) self._result_cache = None return rows - update.alters_data = True - + @generate_unasynced() async def aupdate(self, **kwargs): - return await sync_to_async(self.update)(**kwargs) + """ + Update all elements in the current QuerySet, setting all the given + fields to the appropriate values. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.update)(**kwargs) + self._not_support_combined_queries("update") + if self.query.is_sliced: + raise TypeError("Cannot update a query once a slice has been taken.") + self._for_write = True + query = self.query.chain(sql.UpdateQuery) + query.add_update_values(kwargs) + + # Inline annotations in order_by(), if possible. + new_order_by = [] + for col in query.order_by: + alias = col + descending = False + if isinstance(alias, str) and alias.startswith("-"): + alias = alias.removeprefix("-") + descending = True + if annotation := query.annotations.get(alias): + if getattr(annotation, "contains_aggregate", False): + raise exceptions.FieldError( + f"Cannot update when ordering by an aggregate: {annotation}" + ) + if descending: + annotation = annotation.desc() + new_order_by.append(annotation) + else: + new_order_by.append(col) + query.order_by = tuple(new_order_by) + + # Clear any annotations so that they won't be present in subqueries. + query.annotations = {} + async with transaction.amark_for_rollback_on_error(using=self.db): + rows = await query.aget_compiler(self.db).aexecute_sql(ROW_COUNT) + self._result_cache = None + return rows + + update.alters_data = True aupdate.alters_data = True + @from_codegen def _update(self, values): """ A version of update() that accepts field objects instead of field names. @@ -1278,9 +1898,29 @@ def _update(self, values): self._result_cache = None return query.get_compiler(self.db).execute_sql(ROW_COUNT) + @generate_unasynced() + async def _aupdate(self, values): + """ + A version of update() that accepts field objects instead of field names. + Used primarily for model saving and not intended for use by general + code (it requires too much poking around at model internals to be + useful at that level). + """ + if self.query.is_sliced: + raise TypeError("Cannot update a query once a slice has been taken.") + query = self.query.chain(sql.UpdateQuery) + query.add_update_fields(values) + # Clear any annotations so that they won't be present in subqueries. + query.annotations = {} + self._result_cache = None + return await query.aget_compiler(self.db).aexecute_sql(ROW_COUNT) + + _aupdate.alters_data = True + _aupdate.queryset_only = False _update.alters_data = True _update.queryset_only = False + @from_codegen def exists(self): """ Return True if the QuerySet would have any results, False otherwise. @@ -1289,9 +1929,19 @@ def exists(self): return self.query.has_results(using=self.db) return bool(self._result_cache) + @generate_unasynced() async def aexists(self): - return await sync_to_async(self.exists)() + """ + Return True if the QuerySet would have any results, False otherwise. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.exists)() + if self._result_cache is None: + return await self.query.ahas_results(using=self.db) + return bool(self._result_cache) + @from_codegen def contains(self, obj): """ Return True if the QuerySet contains the provided obj, @@ -1313,14 +1963,45 @@ def contains(self, obj): return obj in self._result_cache return self.filter(pk=obj.pk).exists() + @generate_unasynced() async def acontains(self, obj): - return await sync_to_async(self.contains)(obj=obj) + """ + Return True if the QuerySet contains the provided obj, + False otherwise. + """ + if ASYNC_TRUTH_MARKER: + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.contains)(obj=obj) + self._not_support_combined_queries("contains") + if self._fields is not None: + raise TypeError( + "Cannot call QuerySet.contains() after .values() or .values_list()." + ) + try: + if obj._meta.concrete_model != self.model._meta.concrete_model: + return False + except AttributeError: + raise TypeError("'obj' must be a model instance.") + if not obj._is_pk_set(): + raise ValueError("QuerySet.contains() cannot be used on unsaved objects.") + if self._result_cache is not None: + return obj in self._result_cache + return await self.filter(pk=obj.pk).aexists() + @from_codegen def _prefetch_related_objects(self): # This method can only be called once the result cache has been filled. prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups) self._prefetch_done = True + @generate_unasynced() + async def _aprefetch_related_objects(self): + # This method can only be called once the result cache has been filled. + await aprefetch_related_objects( + self._result_cache, *self._prefetch_related_lookups + ) + self._prefetch_done = True + def explain(self, *, format=None, **options): """ Runs an EXPLAIN on the SQL query this QuerySet would perform, and @@ -1820,6 +2501,7 @@ def db(self): # PRIVATE METHODS # ################### + @from_codegen def _insert( self, objs, @@ -1847,9 +2529,45 @@ def _insert( query.insert_values(fields, objs, raw=raw) return query.get_compiler(using=using).execute_sql(returning_fields) + ################### + # PRIVATE METHODS # + ################### + + @generate_unasynced() + async def _ainsert( + self, + objs, + fields, + returning_fields=None, + raw=False, + using=None, + on_conflict=None, + update_fields=None, + unique_fields=None, + ): + """ + Insert a new record for the given model. This provides an interface to + the InsertQuery class and is how Model.save() is implemented. + """ + self._for_write = True + if using is None: + using = self.db + query = sql.InsertQuery( + self.model, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + ) + query.insert_values(fields, objs, raw=raw) + return await query.aget_compiler(using=using).aexecute_sql(returning_fields) + _insert.alters_data = True _insert.queryset_only = False + _ainsert.alters_data = True + _ainsert.queryset_only = False + + @from_codegen def _batched_insert( self, objs, @@ -1894,6 +2612,54 @@ def _batched_insert( ) return inserted_rows + @generate_unasynced() + async def _abatched_insert( + self, + objs, + fields, + batch_size, + on_conflict=None, + update_fields=None, + unique_fields=None, + ): + """ + Helper method for bulk_create() to insert objs one batch at a time. + """ + if ASYNC_TRUTH_MARKER: + connection = async_connections.get_connection(self.db) + else: + connection = connections[self.db] + ops = connection.ops + max_batch_size = max(ops.bulk_batch_size(fields, objs), 1) + batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size + inserted_rows = [] + bulk_return = connection.features.can_return_rows_from_bulk_insert + for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]: + if bulk_return and ( + on_conflict is None or on_conflict == OnConflict.UPDATE + ): + inserted_rows.extend( + await self._ainsert( + item, + fields=fields, + using=self.db, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + returning_fields=self.model._meta.db_returning_fields, + ) + ) + else: + await self._ainsert( + item, + fields=fields, + using=self.db, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + ) + return inserted_rows + def _chain(self): """ Return a copy of the current QuerySet that's ready for another @@ -1924,12 +2690,20 @@ def _clone(self): c._fields = self._fields return c + @from_codegen def _fetch_all(self): if self._result_cache is None: - self._result_cache = list(self._iterable_class(self)) + self._result_cache = [elt for elt in self._iterable_class(self)] if self._prefetch_related_lookups and not self._prefetch_done: self._prefetch_related_objects() + @generate_unasynced() + async def _afetch_all(self): + if self._result_cache is None: + self._result_cache = [elt async for elt in self._iterable_class(self)] + if self._prefetch_related_lookups and not self._prefetch_done: + await self._aprefetch_related_objects() + def _next_is_sticky(self): """ Indicate that the next filter call and the one following that should @@ -2092,10 +2866,18 @@ def prefetch_related(self, *lookups): clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups return clone + @from_codegen def _prefetch_related_objects(self): prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups) self._prefetch_done = True + @generate_unasynced() + async def _aprefetch_related_objects(self): + await aprefetch_related_objects( + self._result_cache, *self._prefetch_related_lookups + ) + self._prefetch_done = True + def _clone(self): """Same as QuerySet._clone()""" c = self.__class__( @@ -2116,6 +2898,16 @@ def _fetch_all(self): if self._prefetch_related_lookups and not self._prefetch_done: self._prefetch_related_objects() + @from_codegen + def _fetch_then_len(self): + self._fetch_all() + return len(self._result_cache) + + @generate_unasynced() + async def _afetch_then_len(self): + await self._afetch_all() + return len(self._result_cache) + def __len__(self): self._fetch_all() return len(self._result_cache) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 6f90f11f1b2b..79915d6a3dad 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -3,6 +3,7 @@ import re from functools import partial from itertools import chain +from typing import AsyncGenerator from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError @@ -13,8 +14,8 @@ from django.db.models.lookups import Lookup from django.db.models.query_utils import select_related_descend from django.db.models.sql.constants import ( - CURSOR, GET_ITERATOR_CHUNK_SIZE, + CURSOR, MULTI, NO_RESULTS, ORDER_DIR, @@ -27,6 +28,7 @@ from django.utils.functional import cached_property from django.utils.hashable import make_hashable from django.utils.regex_helper import _lazy_re_compile +from django.utils.codegen import from_codegen, generate_unasynced, ASYNC_TRUTH_MARKER class PositionRef(Ref): @@ -752,6 +754,7 @@ def collect_replacements(expressions): result.extend(["ORDER BY", ", ".join(ordering_sqls)]) return result, params + @from_codegen def as_sql(self, with_limits=True, with_col_aliases=False): """ Create the SQL for this query. Return the SQL string and list of @@ -979,6 +982,234 @@ def as_sql(self, with_limits=True, with_col_aliases=False): # Finally do cleanup - get rid of the joins we created above. self.query.reset_refcounts(refcounts_before) + @generate_unasynced() + async def aas_sql(self, with_limits=True, with_col_aliases=False): + """ + Create the SQL for this query. Return the SQL string and list of + parameters. + + If 'with_limits' is False, any limit/offset information is not included + in the query. + """ + refcounts_before = self.query.alias_refcount.copy() + try: + combinator = self.query.combinator + extra_select, order_by, group_by = self.pre_sql_setup( + with_col_aliases=with_col_aliases or bool(combinator), + ) + for_update_part = None + # Is a LIMIT/OFFSET clause needed? + with_limit_offset = with_limits and self.query.is_sliced + combinator = self.query.combinator + features = self.connection.features + if combinator: + if not getattr(features, "supports_select_{}".format(combinator)): + raise NotSupportedError( + "{} is not supported on this database backend.".format( + combinator + ) + ) + result, params = self.get_combinator_sql( + combinator, self.query.combinator_all + ) + elif self.qualify: + result, params = self.get_qualify_sql() + order_by = None + else: + distinct_fields, distinct_params = self.get_distinct() + # This must come after 'select', 'ordering', and 'distinct' + # (see docstring of get_from_clause() for details). + from_, f_params = self.get_from_clause() + try: + where, w_params = ( + self.compile(self.where) if self.where is not None else ("", []) + ) + except EmptyResultSet: + if self.elide_empty: + raise + # Use a predicate that's always False. + where, w_params = "0 = 1", [] + except FullResultSet: + where, w_params = "", [] + try: + having, h_params = ( + self.compile(self.having) + if self.having is not None + else ("", []) + ) + except FullResultSet: + having, h_params = "", [] + result = ["SELECT"] + params = [] + + if self.query.distinct: + distinct_result, distinct_params = self.connection.ops.distinct_sql( + distinct_fields, + distinct_params, + ) + result += distinct_result + params += distinct_params + + out_cols = [] + for _, (s_sql, s_params), alias in self.select + extra_select: + if alias: + s_sql = "%s AS %s" % ( + s_sql, + self.connection.ops.quote_name(alias), + ) + params.extend(s_params) + out_cols.append(s_sql) + + result += [", ".join(out_cols)] + if from_: + result += ["FROM", *from_] + elif self.connection.features.bare_select_suffix: + result += [self.connection.features.bare_select_suffix] + params.extend(f_params) + + if self.query.select_for_update and features.has_select_for_update: + if ( + await self.connection.aget_autocommit() + # Don't raise an exception when database doesn't + # support transactions, as it's a noop. + and features.supports_transactions + ): + raise TransactionManagementError( + "select_for_update cannot be used outside of a transaction." + ) + + if ( + with_limit_offset + and not features.supports_select_for_update_with_limit + ): + raise NotSupportedError( + "LIMIT/OFFSET is not supported with " + "select_for_update on this database backend." + ) + nowait = self.query.select_for_update_nowait + skip_locked = self.query.select_for_update_skip_locked + of = self.query.select_for_update_of + no_key = self.query.select_for_no_key_update + # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the + # backend doesn't support it, raise NotSupportedError to + # prevent a possible deadlock. + if nowait and not features.has_select_for_update_nowait: + raise NotSupportedError( + "NOWAIT is not supported on this database backend." + ) + elif skip_locked and not features.has_select_for_update_skip_locked: + raise NotSupportedError( + "SKIP LOCKED is not supported on this database backend." + ) + elif of and not features.has_select_for_update_of: + raise NotSupportedError( + "FOR UPDATE OF is not supported on this database backend." + ) + elif no_key and not features.has_select_for_no_key_update: + raise NotSupportedError( + "FOR NO KEY UPDATE is not supported on this " + "database backend." + ) + for_update_part = self.connection.ops.for_update_sql( + nowait=nowait, + skip_locked=skip_locked, + of=self.get_select_for_update_of_arguments(), + no_key=no_key, + ) + + if for_update_part and features.for_update_after_from: + result.append(for_update_part) + + if where: + result.append("WHERE %s" % where) + params.extend(w_params) + + grouping = [] + for g_sql, g_params in group_by: + grouping.append(g_sql) + params.extend(g_params) + if grouping: + if distinct_fields: + raise NotImplementedError( + "annotate() + distinct(fields) is not implemented." + ) + order_by = order_by or self.connection.ops.force_no_ordering() + result.append("GROUP BY %s" % ", ".join(grouping)) + if self._meta_ordering: + order_by = None + if having: + if not grouping: + result.extend(self.connection.ops.force_group_by()) + result.append("HAVING %s" % having) + params.extend(h_params) + + if self.query.explain_info: + result.insert( + 0, + self.connection.ops.explain_query_prefix( + self.query.explain_info.format, + **self.query.explain_info.options, + ), + ) + + if order_by: + ordering = [] + for _, (o_sql, o_params, _) in order_by: + ordering.append(o_sql) + params.extend(o_params) + order_by_sql = "ORDER BY %s" % ", ".join(ordering) + if combinator and features.requires_compound_order_by_subquery: + result = ["SELECT * FROM (", *result, ")", order_by_sql] + else: + result.append(order_by_sql) + + if with_limit_offset: + result.append( + self.connection.ops.limit_offset_sql( + self.query.low_mark, self.query.high_mark + ) + ) + + if for_update_part and not features.for_update_after_from: + result.append(for_update_part) + + if self.query.subquery and extra_select: + # If the query is used as a subquery, the extra selects would + # result in more columns than the left-hand side expression is + # expecting. This can happen when a subquery uses a combination + # of order_by() and distinct(), forcing the ordering expressions + # to be selected as well. Wrap the query in another subquery + # to exclude extraneous selects. + sub_selects = [] + sub_params = [] + for index, (select, _, alias) in enumerate(self.select, start=1): + if alias: + sub_selects.append( + "%s.%s" + % ( + self.connection.ops.quote_name("subquery"), + self.connection.ops.quote_name(alias), + ) + ) + else: + select_clone = select.relabeled_clone( + {select.alias: "subquery"} + ) + subselect, subparams = select_clone.as_sql( + self, self.connection + ) + sub_selects.append(subselect) + sub_params.extend(subparams) + return "SELECT %s FROM (%s) subquery" % ( + ", ".join(sub_selects), + " ".join(result), + ), tuple(sub_params + params) + + return " ".join(result), tuple(params) + finally: + # Finally do cleanup - get rid of the joins we created above. + self.query.reset_refcounts(refcounts_before) + def get_default_columns( self, select_mask, start_alias=None, opts=None, from_parent=None ): @@ -1560,6 +1791,7 @@ def composite_fields_to_tuples(self, rows, expressions): yield row + @from_codegen def results_iter( self, results=None, @@ -1572,6 +1804,42 @@ def results_iter( results = self.execute_sql( MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size ) + else: + # XXX wrong + # this is forcing evaluation of athing way to early + # instead of being an actual iterable + if isinstance(results, AsyncGenerator): + results = [r for r in results] + fields = [s[0] for s in self.select[0 : self.col_count]] + converters = self.get_converters(fields) + rows = chain.from_iterable(results) + if converters: + rows = self.apply_converters(rows, converters) + if self.has_composite_fields(fields): + rows = self.composite_fields_to_tuples(rows, fields) + if tuple_expected: + rows = map(tuple, rows) + return rows + + @generate_unasynced() + async def aresults_iter( + self, + results=None, + tuple_expected=False, + chunked_fetch=False, + chunk_size=GET_ITERATOR_CHUNK_SIZE, + ): + """Return an iterator over the results from executing this query.""" + if results is None: + results = await self.aexecute_sql( + MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size + ) + else: + # XXX wrong + # this is forcing evaluation of athing way to early + # instead of being an actual iterable + if isinstance(results, AsyncGenerator): + results = [r async for r in results] fields = [s[0] for s in self.select[0 : self.col_count]] converters = self.get_converters(fields) rows = chain.from_iterable(results) @@ -1583,6 +1851,7 @@ def results_iter( rows = map(tuple, rows) return rows + @from_codegen def has_results(self): """ Backends (e.g. NoSQL) can override this in order to use optimized @@ -1590,6 +1859,15 @@ def has_results(self): """ return bool(self.execute_sql(SINGLE)) + @generate_unasynced() + async def ahas_results(self): + """ + Backends (e.g. NoSQL) can override this in order to use optimized + versions of "query has any results." + """ + return bool(await self.aexecute_sql(SINGLE)) + + @from_codegen def execute_sql( self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE ): @@ -1619,6 +1897,7 @@ def execute_sql( cursor = self.connection.chunked_cursor() else: cursor = self.connection.cursor() + try: cursor.execute(sql, params) except Exception: @@ -1631,10 +1910,9 @@ def execute_sql( return cursor.rowcount finally: cursor.close() - if result_type == CURSOR: - # Give the caller the cursor to process and close. + elif result_type == CURSOR: return cursor - if result_type == SINGLE: + elif result_type == SINGLE: try: val = cursor.fetchone() if val: @@ -1643,23 +1921,123 @@ def execute_sql( finally: # done with the cursor cursor.close() - if result_type == NO_RESULTS: + elif result_type == NO_RESULTS: cursor.close() return + elif result_type == ROW_COUNT: + try: + return cursor.rowcount + finally: + cursor.close() + else: + assert result_type == MULTI + result = cursor_iter( + cursor, + self.connection.features.empty_fetchmany_value, + self.col_count if self.has_extra_select else None, + chunk_size, + ) + if not chunked_fetch or not self.connection.features.can_use_chunked_reads: + # If we are using non-chunked reads, we return the same data + # structure as normally, but ensure it is all read into memory + # before going any further. Use chunked_fetch if requested, + # unless the database doesn't support it. + return [elt for elt in result] + return result + + @generate_unasynced() + async def aexecute_sql( + self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE + ): + """ + Run the query against the database and return the result(s). The + return value depends on the value of result_type. - result = cursor_iter( - cursor, - self.connection.features.empty_fetchmany_value, - self.col_count if self.has_extra_select else None, - chunk_size, - ) - if not chunked_fetch or not self.connection.features.can_use_chunked_reads: - # If we are using non-chunked reads, we return the same data - # structure as normally, but ensure it is all read into memory - # before going any further. Use chunked_fetch if requested, - # unless the database doesn't support it. - return list(result) - return result + When result_type is: + - MULTI: Retrieves all rows using fetchmany(). Wraps in an iterator for + chunked reads when supported. + - SINGLE: Retrieves a single row using fetchone(). + - ROW_COUNT: Retrieves the number of rows in the result. + - CURSOR: Runs the query, and returns the cursor object. It is the + caller's responsibility to close the cursor. + """ + result_type = result_type or NO_RESULTS + try: + sql, params = await self.aas_sql() + if not sql: + raise EmptyResultSet + except EmptyResultSet: + if result_type == MULTI: + return iter([]) + else: + return + # if "pg_sleep" in sql: + # raise ValueError("FOUND") + if ASYNC_TRUTH_MARKER: + if chunked_fetch: + cursor = await (await self.connection.achunked_cursor()).__aenter__() + else: + cursor = await self.connection.acursor().__aenter__() + else: + if chunked_fetch: + cursor = self.connection.chunked_cursor() + else: + cursor = self.connection.cursor() + + try: + await cursor.aexecute(sql, params) + except Exception: + # Might fail for server-side cursors (e.g. connection closed) + await cursor.aclose() + raise + + if result_type == ROW_COUNT: + try: + return cursor.rowcount + finally: + cursor.close() + elif result_type == CURSOR: + return cursor + elif result_type == SINGLE: + try: + val = await cursor.afetchone() + if val: + return val[0 : self.col_count] + return val + finally: + # done with the cursor + await cursor.aclose() + elif result_type == NO_RESULTS: + await cursor.aclose() + return + elif result_type == ROW_COUNT: + try: + return cursor.rowcount + finally: + await cursor.aclose() + else: + assert result_type == MULTI + if ASYNC_TRUTH_MARKER: + result = acursor_iter( + cursor, + self.connection.features.empty_fetchmany_value, + self.col_count if self.has_extra_select else None, + chunk_size, + ) + else: + result = cursor_iter( + cursor, + self.connection.features.empty_fetchmany_value, + self.col_count if self.has_extra_select else None, + chunk_size, + ) + if not chunked_fetch or not self.connection.features.can_use_chunked_reads: + # If we are using non-chunked reads, we return the same data + # structure as normally, but ensure it is all read into memory + # before going any further. Use chunked_fetch if requested, + # unless the database doesn't support it. + return [elt async for elt in result] + return result def as_subquery_condition(self, alias, columns, compiler): qn = compiler.quote_name_unless_alias @@ -1674,6 +2052,7 @@ def as_subquery_condition(self, alias, columns, compiler): sql, params = query.as_sql(compiler, self.connection) return "EXISTS %s" % sql, params + @from_codegen def explain_query(self): result = list(self.execute_sql()) # Some backends return 1 item tuples with strings, and others return @@ -1687,6 +2066,20 @@ def explain_query(self): else: yield value + @generate_unasynced() + async def aexplain_query(self): + result = list(await self.aexecute_sql()) + # Some backends return 1 item tuples with strings, and others return + # tuples with integers and strings. Flatten them out into strings. + format_ = self.query.explain_info.format + output_formatter = json.dumps if format_ and format_.lower() == "json" else str + for row in result: + for value in row: + if not isinstance(value, str): + yield " ".join([output_formatter(c) for c in value]) + else: + yield value + class SQLInsertCompiler(SQLCompiler): returning_fields = None @@ -1801,6 +2194,7 @@ def assemble_as_sql(self, fields, value_rows): return placeholder_rows, param_rows + @from_codegen def as_sql(self): # We don't need quote_name_unless_alias() here, since these are all # going to be column names (so we can avoid the extra overhead). @@ -1881,6 +2275,88 @@ def as_sql(self): for p, vals in zip(placeholder_rows, param_rows) ] + @generate_unasynced() + async def aas_sql(self): + # We don't need quote_name_unless_alias() here, since these are all + # going to be column names (so we can avoid the extra overhead). + qn = self.connection.ops.quote_name + opts = self.query.get_meta() + insert_statement = self.connection.ops.insert_statement( + on_conflict=self.query.on_conflict, + ) + result = ["%s %s" % (insert_statement, qn(opts.db_table))] + fields = self.query.fields or [opts.pk] + result.append("(%s)" % ", ".join(qn(f.column) for f in fields)) + + if self.query.fields: + value_rows = [ + [ + self.prepare_value(field, self.pre_save_val(field, obj)) + for field in fields + ] + for obj in self.query.objs + ] + else: + # An empty object. + value_rows = [ + [self.connection.ops.pk_default_value()] for _ in self.query.objs + ] + fields = [None] + + # Currently the backends just accept values when generating bulk + # queries and generate their own placeholders. Doing that isn't + # necessary and it should be possible to use placeholders and + # expressions in bulk inserts too. + can_bulk = ( + not self.returning_fields and self.connection.features.has_bulk_insert + ) + + placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) + + on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql( + fields, + self.query.on_conflict, + (f.column for f in self.query.update_fields), + (f.column for f in self.query.unique_fields), + ) + if ( + self.returning_fields + and self.connection.features.can_return_columns_from_insert + ): + if self.connection.features.can_return_rows_from_bulk_insert: + result.append( + self.connection.ops.bulk_insert_sql(fields, placeholder_rows) + ) + params = param_rows + else: + result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) + params = [param_rows[0]] + if on_conflict_suffix_sql: + result.append(on_conflict_suffix_sql) + # Skip empty r_sql to allow subclasses to customize behavior for + # 3rd party backends. Refs #19096. + r_sql, self.returning_params = self.connection.ops.return_insert_columns( + self.returning_fields + ) + if r_sql: + result.append(r_sql) + params += [self.returning_params] + return [(" ".join(result), tuple(chain.from_iterable(params)))] + + if can_bulk: + result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) + if on_conflict_suffix_sql: + result.append(on_conflict_suffix_sql) + return [(" ".join(result), tuple(p for ps in param_rows for p in ps))] + else: + if on_conflict_suffix_sql: + result.append(on_conflict_suffix_sql) + return [ + (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) + for p, vals in zip(placeholder_rows, param_rows) + ] + + @from_codegen def execute_sql(self, returning_fields=None): assert not ( returning_fields @@ -1923,6 +2399,60 @@ def execute_sql(self, returning_fields=None): ), ) ] + + else: + # Backend doesn't support returning fields and no auto-field + # that can be retrieved from `last_insert_id` was specified. + return [] + converters = self.get_converters(cols) + if converters: + rows = self.apply_converters(rows, converters) + return list(rows) + + @generate_unasynced() + async def aexecute_sql(self, returning_fields=None): + assert not ( + returning_fields + and len(self.query.objs) != 1 + and not self.connection.features.can_return_rows_from_bulk_insert + ) + opts = self.query.get_meta() + self.returning_fields = returning_fields + cols = [] + async with self.connection.acursor() as cursor: + for sql, params in self.as_sql(): + await cursor.aexecute(sql, params) + if not self.returning_fields: + return [] + if ( + self.connection.features.can_return_rows_from_bulk_insert + and len(self.query.objs) > 1 + ): + rows = await self.connection.ops.afetch_returned_insert_rows(cursor) + cols = [field.get_col(opts.db_table) for field in self.returning_fields] + elif self.connection.features.can_return_columns_from_insert: + assert len(self.query.objs) == 1 + rows = [ + await self.connection.ops.afetch_returned_insert_columns( + cursor, + self.returning_params, + ) + ] + cols = [field.get_col(opts.db_table) for field in self.returning_fields] + elif returning_fields and isinstance( + returning_field := returning_fields[0], AutoField + ): + cols = [returning_field.get_col(opts.db_table)] + rows = [ + ( + self.connection.ops.last_insert_id( + cursor, + opts.db_table, + returning_field.column, + ), + ) + ] + else: # Backend doesn't support returning fields and no auto-field # that can be retrieved from `last_insert_id` was specified. @@ -1968,6 +2498,7 @@ def _as_sql(self, query): return delete, () return f"{delete} WHERE {where}", tuple(params) + @from_codegen def as_sql(self): """ Create the SQL for this query. Return the SQL string and list of @@ -1992,8 +2523,35 @@ def as_sql(self): outerq.add_filter("pk__in", innerq) return self._as_sql(outerq) + @generate_unasynced() + async def aas_sql(self): + """ + Create the SQL for this query. Return the SQL string and list of + parameters. + """ + if self.single_alias and ( + self.connection.features.delete_can_self_reference_subquery + or not self.contains_self_reference_subquery + ): + return self._as_sql(self.query) + innerq = self.query.clone() + innerq.__class__ = Query + innerq.clear_select_clause() + pk = self.query.model._meta.pk + innerq.select = [pk.get_col(self.query.get_initial_alias())] + outerq = Query(self.query.model) + if not self.connection.features.update_can_self_select: + # Force the materialization of the inner query to allow reference + # to the target table on MySQL. + sql, params = innerq.get_compiler(connection=self.connection).as_sql() + innerq = RawSQL("SELECT * FROM (%s) subquery" % sql, params) + outerq.add_filter("pk__in", innerq) + return self._as_sql(outerq) + class SQLUpdateCompiler(SQLCompiler): + + @from_codegen def as_sql(self): """ Create the SQL for this query. Return the SQL string and list of @@ -2058,6 +2616,72 @@ def as_sql(self): result.append("WHERE %s" % where) return " ".join(result), tuple(update_params + params) + @generate_unasynced() + async def aas_sql(self): + """ + Create the SQL for this query. Return the SQL string and list of + parameters. + """ + self.pre_sql_setup() + if not self.query.values: + return "", () + qn = self.quote_name_unless_alias + values, update_params = [], [] + for field, model, val in self.query.values: + if hasattr(val, "resolve_expression"): + val = val.resolve_expression( + self.query, allow_joins=False, for_save=True + ) + if val.contains_aggregate: + raise FieldError( + "Aggregate functions are not allowed in this query " + "(%s=%r)." % (field.name, val) + ) + if val.contains_over_clause: + raise FieldError( + "Window expressions are not allowed in this query " + "(%s=%r)." % (field.name, val) + ) + elif hasattr(val, "prepare_database_save"): + if field.remote_field: + val = val.prepare_database_save(field) + else: + raise TypeError( + "Tried to update field %s with a model instance, %r. " + "Use a value compatible with %s." + % (field, val, field.__class__.__name__) + ) + val = field.get_db_prep_save(val, connection=self.connection) + + # Getting the placeholder for the field. + if hasattr(field, "get_placeholder"): + placeholder = field.get_placeholder(val, self, self.connection) + else: + placeholder = "%s" + name = field.column + if hasattr(val, "as_sql"): + sql, params = self.compile(val) + values.append("%s = %s" % (qn(name), placeholder % sql)) + update_params.extend(params) + elif val is not None: + values.append("%s = %s" % (qn(name), placeholder)) + update_params.append(val) + else: + values.append("%s = NULL" % qn(name)) + table = self.query.base_table + result = [ + "UPDATE %s SET" % qn(table), + ", ".join(values), + ] + try: + where, params = self.compile(self.query.where) + except FullResultSet: + params = [] + else: + result.append("WHERE %s" % where) + return " ".join(result), tuple(update_params + params) + + @from_codegen def execute_sql(self, result_type): """ Execute the specified update. Return the number of rows affected by @@ -2079,7 +2703,31 @@ def execute_sql(self, result_type): is_empty = False return row_count - def pre_sql_setup(self): + @generate_unasynced() + async def aexecute_sql(self, result_type): + """ + Execute the specified update. Return the number of rows affected by + the primary update query. The "primary update query" is the first + non-empty query that is executed. Row counts for any subsequent, + related queries are not available. + """ + row_count = await super().aexecute_sql(result_type) + is_empty = row_count is None + row_count = row_count or 0 + + for query in self.query.get_related_updates(): + # If the result_type is NO_RESULTS then the aux_row_count is None. + aux_row_count = await query.get_compiler(self.using).aexecute_sql( + result_type + ) + if is_empty and aux_row_count: + # Returns the row count for any related updates as the number of + # rows updated. + row_count = aux_row_count + is_empty = False + return row_count + + def pre_sql_setup(self, with_col_aliases=False): """ If the update depends on results from other tables, munge the "where" conditions to match the format required for (portable) SQL updates. @@ -2116,7 +2764,7 @@ def pre_sql_setup(self): related_ids_index.append((related, len(fields))) fields.append(related._meta.pk.name) query.add_fields(fields) - super().pre_sql_setup() + super().pre_sql_setup(with_col_aliases=with_col_aliases) is_composite_pk = meta.is_composite_pk must_pre_select = ( @@ -2146,6 +2794,8 @@ def pre_sql_setup(self): class SQLAggregateCompiler(SQLCompiler): + + @from_codegen def as_sql(self): """ Create the SQL for this query. Return the SQL string and list of @@ -2169,14 +2819,58 @@ def as_sql(self): params += inner_query_params return sql, params + @generate_unasynced() + async def aas_sql(self): + """ + Create the SQL for this query. Return the SQL string and list of + parameters. + """ + sql, params = [], [] + for annotation in self.query.annotation_select.values(): + ann_sql, ann_params = self.compile(annotation) + ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params) + sql.append(ann_sql) + params.extend(ann_params) + self.col_count = len(self.query.annotation_select) + sql = ", ".join(sql) + params = tuple(params) + + inner_query_sql, inner_query_params = self.query.inner_query.get_compiler( + self.using, + elide_empty=self.elide_empty, + ).as_sql(with_col_aliases=True) + sql = "SELECT %s FROM (%s) subquery" % (sql, inner_query_sql) + params += inner_query_params + return sql, params + +@from_codegen def cursor_iter(cursor, sentinel, col_count, itersize): """ Yield blocks of rows from a cursor and ensure the cursor is closed when done. """ try: - for rows in iter((lambda: cursor.fetchmany(itersize)), sentinel): + while True: + rows = cursor.fetchmany(itersize) + if rows == sentinel: + break yield rows if col_count is None else [r[:col_count] for r in rows] finally: cursor.close() + + +@generate_unasynced() +async def acursor_iter(cursor, sentinel, col_count, itersize): + """ + Yield blocks of rows from a cursor and ensure the cursor is closed when + done. + """ + try: + while True: + rows = await cursor.afetchmany(itersize) + if rows == sentinel: + break + yield rows if col_count is None else [r[:col_count] for r in rows] + finally: + await cursor.aclose() diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 709405b0dfb8..ce4a1bd2eff7 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -9,8 +9,12 @@ # Namedtuples for sql.* internal use. # How many results to expect from a cursor.execute call +# multiple rows are expected MULTI = "multi" +# a single row is expected SINGLE = "single" +# instead of returning the rows, return the row count +CURSOR = "cursor" NO_RESULTS = "no results" # Rather than returning results, returns: CURSOR = "cursor" diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 6fbf854e67f0..038cb701dcc5 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -17,7 +17,12 @@ from string import ascii_uppercase from django.core.exceptions import FieldDoesNotExist, FieldError -from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections +from django.db import ( + DEFAULT_DB_ALIAS, + NotSupportedError, + connections, + async_connections, +) from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import ( @@ -42,6 +47,7 @@ from django.db.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE from django.db.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin from django.db.models.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode +from django.utils.codegen import ASYNC_TRUTH_MARKER, from_codegen, generate_unasynced from django.utils.functional import cached_property from django.utils.regex_helper import _lazy_re_compile from django.utils.tree import Node @@ -355,11 +361,24 @@ def __deepcopy__(self, memo): memo[id(self)] = result return result - def get_compiler(self, using=None, connection=None, elide_empty=True): + def get_compiler( + self, using=None, connection=None, elide_empty=True, raise_on_miss=False + ): + if using is None and connection is None: + raise ValueError("Need either using or connection") + if using: + connection = connections.get_item(using, raise_on_miss=raise_on_miss) + return connection.ops.compiler(self.compiler)( + self, connection, using, elide_empty + ) + + def aget_compiler( + self, using=None, connection=None, elide_empty=True, raise_on_miss=True + ): if using is None and connection is None: raise ValueError("Need either using or connection") if using: - connection = connections[using] + connection = async_connections.get_connection(using) return connection.ops.compiler(self.compiler)( self, connection, using, elide_empty ) @@ -443,6 +462,7 @@ def _get_col(self, target, field, alias): alias = None return target.get_col(alias, field) + @from_codegen def get_aggregation(self, using, aggregate_exprs): """ Return the dictionary with the values of the existing aggregations. @@ -636,6 +656,204 @@ def get_aggregation(self, using, aggregate_exprs): return dict(zip(outer_query.annotation_select, result)) + @generate_unasynced() + async def aget_aggregation(self, using, aggregate_exprs): + """ + Return the dictionary with the values of the existing aggregations. + """ + if not aggregate_exprs: + return {} + # Store annotation mask prior to temporarily adding aggregations for + # resolving purpose to facilitate their subsequent removal. + refs_subquery = False + refs_window = False + replacements = {} + annotation_select_mask = self.annotation_select_mask + for alias, aggregate_expr in aggregate_exprs.items(): + self.check_alias(alias) + aggregate = aggregate_expr.resolve_expression( + self, allow_joins=True, reuse=None, summarize=True + ) + if not aggregate.contains_aggregate: + raise TypeError("%s is not an aggregate expression" % alias) + # Temporarily add aggregate to annotations to allow remaining + # members of `aggregates` to resolve against each others. + self.append_annotation_mask([alias]) + aggregate_refs = aggregate.get_refs() + refs_subquery |= any( + getattr(self.annotations[ref], "contains_subquery", False) + for ref in aggregate_refs + ) + refs_window |= any( + getattr(self.annotations[ref], "contains_over_clause", True) + for ref in aggregate_refs + ) + aggregate = aggregate.replace_expressions(replacements) + self.annotations[alias] = aggregate + replacements[Ref(alias, aggregate)] = aggregate + # Stash resolved aggregates now that they have been allowed to resolve + # against each other. + aggregates = {alias: self.annotations.pop(alias) for alias in aggregate_exprs} + self.set_annotation_mask(annotation_select_mask) + # Existing usage of aggregation can be determined by the presence of + # selected aggregates but also by filters against aliased aggregates. + _, having, qualify = self.where.split_having_qualify() + has_existing_aggregation = ( + any( + getattr(annotation, "contains_aggregate", True) + for annotation in self.annotations.values() + ) + or having + ) + set_returning_annotations = { + alias + for alias, annotation in self.annotation_select.items() + if getattr(annotation, "set_returning", False) + } + # Decide if we need to use a subquery. + # + # Existing aggregations would cause incorrect results as + # get_aggregation() must produce just one result and thus must not use + # GROUP BY. + # + # If the query has limit or distinct, or uses set operations, then + # those operations must be done in a subquery so that the query + # aggregates on the limit and/or distinct results instead of applying + # the distinct and limit after the aggregation. + if ( + isinstance(self.group_by, tuple) + or self.is_sliced + or has_existing_aggregation + or refs_subquery + or refs_window + or qualify + or self.distinct + or self.combinator + or set_returning_annotations + ): + from django.db.models.sql.subqueries import AggregateQuery + + inner_query = self.clone() + inner_query.subquery = True + outer_query = AggregateQuery(self.model, inner_query) + inner_query.select_for_update = False + inner_query.select_related = False + inner_query.set_annotation_mask(self.annotation_select) + # Queries with distinct_fields need ordering and when a limit is + # applied we must take the slice from the ordered query. Otherwise + # no need for ordering. + inner_query.clear_ordering(force=False) + if not inner_query.distinct: + # If the inner query uses default select and it has some + # aggregate annotations, then we must make sure the inner + # query is grouped by the main model's primary key. However, + # clearing the select clause can alter results if distinct is + # used. + if inner_query.default_cols and has_existing_aggregation: + inner_query.group_by = ( + self.model._meta.pk.get_col(inner_query.get_initial_alias()), + ) + inner_query.default_cols = False + if not qualify and not self.combinator: + # Mask existing annotations that are not referenced by + # aggregates to be pushed to the outer query unless + # filtering against window functions or if the query is + # combined as both would require complex realiasing logic. + annotation_mask = set() + if isinstance(self.group_by, tuple): + for expr in self.group_by: + annotation_mask |= expr.get_refs() + for aggregate in aggregates.values(): + annotation_mask |= aggregate.get_refs() + # Avoid eliding expressions that might have an incidence on + # the implicit grouping logic. + for annotation_alias, annotation in self.annotation_select.items(): + if annotation.get_group_by_cols(): + annotation_mask.add(annotation_alias) + inner_query.set_annotation_mask(annotation_mask) + # Annotations that possibly return multiple rows cannot + # be masked as they might have an incidence on the query. + annotation_mask |= set_returning_annotations + + # Add aggregates to the outer AggregateQuery. This requires making + # sure all columns referenced by the aggregates are selected in the + # inner query. It is achieved by retrieving all column references + # by the aggregates, explicitly selecting them in the inner query, + # and making sure the aggregates are repointed to them. + col_refs = {} + for alias, aggregate in aggregates.items(): + replacements = {} + for col in self._gen_cols([aggregate], resolve_refs=False): + if not (col_ref := col_refs.get(col)): + index = len(col_refs) + 1 + col_alias = f"__col{index}" + col_ref = Ref(col_alias, col) + col_refs[col] = col_ref + inner_query.add_annotation(col, col_alias) + replacements[col] = col_ref + outer_query.annotations[alias] = aggregate.replace_expressions( + replacements + ) + if ( + inner_query.select == () + and not inner_query.default_cols + and not inner_query.annotation_select_mask + ): + # In case of Model.objects[0:3].count(), there would be no + # field selected in the inner query, yet we must use a subquery. + # So, make sure at least one field is selected. + inner_query.select = ( + self.model._meta.pk.get_col(inner_query.get_initial_alias()), + ) + else: + outer_query = self + self.select = () + self.selected = None + self.default_cols = False + self.extra = {} + if self.annotations: + # Inline reference to existing annotations and mask them as + # they are unnecessary given only the summarized aggregations + # are requested. + replacements = { + Ref(alias, annotation): annotation + for alias, annotation in self.annotations.items() + } + self.annotations = { + alias: aggregate.replace_expressions(replacements) + for alias, aggregate in aggregates.items() + } + else: + self.annotations = aggregates + self.set_annotation_mask(aggregates) + + empty_set_result = [ + expression.empty_result_set_value + for expression in outer_query.annotation_select.values() + ] + elide_empty = not any(result is NotImplemented for result in empty_set_result) + outer_query.clear_ordering(force=True) + outer_query.clear_limits() + outer_query.select_for_update = False + outer_query.select_related = False + if ASYNC_TRUTH_MARKER: + compiler = outer_query.aget_compiler(using, elide_empty=elide_empty) + else: + compiler = outer_query.get_compiler(using, elide_empty=elide_empty) + result = await compiler.aexecute_sql(SINGLE) + if result is None: + result = empty_set_result + else: + cols = outer_query.annotation_select.values() + converters = compiler.get_converters(cols) + rows = compiler.apply_converters((result,), converters) + if compiler.has_composite_fields(cols): + rows = compiler.composite_fields_to_tuples(rows, cols) + result = next(rows) + + return dict(zip(outer_query.annotation_select, result)) + + @from_codegen def get_count(self, using): """ Perform a COUNT() query using the current filter constraints. @@ -643,6 +861,14 @@ def get_count(self, using): obj = self.clone() return obj.get_aggregation(using, {"__count": Count("*")})["__count"] + @generate_unasynced() + async def aget_count(self, using): + """ + Perform a COUNT() query using the current filter constraints. + """ + obj = self.clone() + return (await obj.aget_aggregation(using, {"__count": Count("*")}))["__count"] + def has_filters(self): return self.where @@ -668,11 +894,22 @@ def exists(self, limit=True): q.add_annotation(Value(1), "a") return q + @from_codegen def has_results(self, using): q = self.exists() compiler = q.get_compiler(using=using) return compiler.has_results() + @generate_unasynced() + async def ahas_results(self, using): + q = self.exists() + if ASYNC_TRUTH_MARKER: + compiler = q.aget_compiler(using=using) + else: + compiler = q.get_compiler(using=using) + return await compiler.ahas_results() + + @from_codegen def explain(self, using, format=None, **options): q = self.clone() for option_name in options: @@ -685,6 +922,22 @@ def explain(self, using, format=None, **options): compiler = q.get_compiler(using=using) return "\n".join(compiler.explain_query()) + @generate_unasynced() + async def aexplain(self, using, format=None, **options): + q = self.clone() + for option_name in options: + if ( + not EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name) + or "--" in option_name + ): + raise ValueError(f"Invalid option name: {option_name!r}.") + q.explain_info = ExplainInfo(format, options) + if ASYNC_TRUTH_MARKER: + compiler = q.aget_compiler(using=using) + else: + compiler = q.get_compiler(using=using) + return "\n".join(await compiler.aexplain_query()) + def combine(self, rhs, connector): """ Merge the 'rhs' query into the current one (with any 'rhs' effects diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index b2810c8413b5..52ff4644cdd6 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -9,6 +9,7 @@ ROW_COUNT, ) from django.db.models.sql.query import Query +from django.utils.codegen import from_codegen, generate_unasynced __all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"] @@ -18,11 +19,21 @@ class DeleteQuery(Query): compiler = "SQLDeleteCompiler" + @from_codegen def do_query(self, table, where, using): self.alias_map = {table: self.alias_map[table]} self.where = where + return self.get_compiler(using).execute_sql(ROW_COUNT) + @generate_unasynced() + async def ado_query(self, table, where, using): + self.alias_map = {table: self.alias_map[table]} + self.where = where + + return await self.aget_compiler(using).aexecute_sql(ROW_COUNT) + + @from_codegen def delete_batch(self, pk_list, using): """ Set up and execute delete queries for all the objects in pk_list. @@ -44,6 +55,28 @@ def delete_batch(self, pk_list, using): ) return num_deleted + @generate_unasynced() + async def adelete_batch(self, pk_list, using): + """ + Set up and execute delete queries for all the objects in pk_list. + + More than one physical query may be executed if there are a + lot of values in pk_list. + """ + # number of objects deleted + num_deleted = 0 + field = self.get_meta().pk + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + self.clear_where() + self.add_filter( + f"{field.attname}__in", + pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE], + ) + num_deleted += await self.ado_query( + self.get_meta().db_table, self.where, using=using + ) + return num_deleted + class UpdateQuery(Query): """An UPDATE SQL query.""" diff --git a/django/db/transaction.py b/django/db/transaction.py index 0c2eee8e7364..c2b93e8300e1 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -1,4 +1,9 @@ -from contextlib import ContextDecorator, contextmanager +import asyncio +from contextlib import ContextDecorator, asynccontextmanager, contextmanager +import contextvars +import weakref + +from asgiref.sync import sync_to_async from django.db import ( DEFAULT_DB_ALIAS, @@ -6,7 +11,10 @@ Error, ProgrammingError, connections, + async_connections, + should_use_sync_fallback, ) +from django.utils.codegen import ASYNC_TRUTH_MARKER, generate_unasynced class TransactionManagementError(ProgrammingError): @@ -25,6 +33,12 @@ def get_connection(using=None): return connections[using] +async def aget_connection(using=None): + if using is None: + using = DEFAULT_DB_ALIAS + return async_connections.get_connection(using) + + def get_autocommit(using=None): """Get the autocommit status of the connection.""" return get_connection(using).get_autocommit() @@ -97,7 +111,35 @@ def set_rollback(rollback, using=None): return get_connection(using).set_rollback(rollback) -@contextmanager +class MarkForRollbackOnError: + def __init__(self, using): + self.using = using + + def __enter__(self): + return self + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_val is not None: + connection = await aget_connection(self.using) + if connection.in_atomic_block: + connection.needs_rollback = True + connection.rollback_exc = exc_val + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_val is not None: + connection = get_connection(self.using) + if connection.in_atomic_block: + connection.needs_rollback = True + connection.rollback_exc = exc_val + + +def amark_for_rollback_on_error(using=None): + return MarkForRollbackOnError(using=using) + + def mark_for_rollback_on_error(using=None): """ Internal low-level utility to mark a transaction as "needs rollback" when @@ -116,14 +158,7 @@ def mark_for_rollback_on_error(using=None): but it uses low-level utilities to avoid performance overhead. """ - try: - yield - except Exception as exc: - connection = get_connection(using) - if connection.in_atomic_block: - connection.needs_rollback = True - connection.rollback_exc = exc - raise + return MarkForRollbackOnError(using=using) def on_commit(func, using=None, robust=False): @@ -179,7 +214,25 @@ def __init__(self, using, savepoint, durable): self.durable = durable self._from_testcase = False + # tracking how many atomic transactions I have done + _atomic_depth_ctx: dict[str, contextvars.ContextVar] = {} + + def atomic_depth_var(self, using): + if using is None: + using = DEFAULT_DB_ALIAS + # XXX race? + if using not in self._atomic_depth_ctx: + # XXX awkward context var + self._atomic_depth_ctx[using] = contextvars.ContextVar(using, default=0) + return self._atomic_depth_ctx[using] + + def current_atomic_depth(self, using): + return self.atomic_depth_var(using).get() + def __enter__(self): + + current_depth = self.atomic_depth_var(self.using) + current_depth.set(current_depth.get() + 1) connection = get_connection(self.using) if ( @@ -221,7 +274,69 @@ def __enter__(self): if connection.in_atomic_block: connection.atomic_blocks.append(self) + atxn_locks = weakref.WeakKeyDictionary() + + def get_atxn_lock(self, connection) -> asyncio.Lock: + lock = self.atxn_locks.get(connection, None) + if lock is None: + lock = self.atxn_locks[connection] = asyncio.Lock() + return lock + + # need to figure out how to generate __enter__ from __aenter__ + # @generate_unasynced() + async def __aenter__(self): + + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.__enter__)() + current_depth = self.atomic_depth_var(self.using) + current_depth.set(current_depth.get() + 1) + connection = await aget_connection(self.using) + + if ( + self.durable + and connection.atomic_blocks + and not connection.atomic_blocks[-1]._from_testcase + ): + raise RuntimeError( + "A durable atomic block cannot be nested within another " + "atomic block." + ) + + # XXX race + async with self.get_atxn_lock(connection): + if not connection.in_atomic_block: + # Reset state when entering an outermost atomic block. + connection.commit_on_exit = True + connection.needs_rollback = False + if not (await connection.aget_autocommit()): + # Pretend we're already in an atomic block to bypass the code + # that disables autocommit to enter a transaction, and make a + # note to deal with this case in __exit__. + connection.in_atomic_block = True + connection.commit_on_exit = False + + if connection.in_atomic_block: + # We're already in a transaction; create a savepoint, unless we + # were told not to or we're already waiting for a rollback. The + # second condition avoids creating useless savepoints and prevents + # overwriting needs_rollback until the rollback is performed. + if self.savepoint and not connection.needs_rollback: + sid = await connection.asavepoint() + connection.savepoint_ids.append(sid) + else: + connection.savepoint_ids.append(None) + else: + await connection.aset_autocommit( + False, force_begin_transaction_with_broken_autocommit=True + ) + connection.in_atomic_block = True + + if connection.in_atomic_block: + connection.atomic_blocks.append(self) + def __exit__(self, exc_type, exc_value, traceback): + current_depth = self.atomic_depth_var(self.using) + current_depth.set(current_depth.get() - 1) connection = get_connection(self.using) if connection.in_atomic_block: @@ -312,6 +427,103 @@ def __exit__(self, exc_type, exc_value, traceback): else: connection.in_atomic_block = False + # XXX try to get this working through generation as well + async def __aexit__(self, exc_type, exc_value, traceback): + if should_use_sync_fallback(ASYNC_TRUTH_MARKER): + return await sync_to_async(self.__exit__)(exc_type, exc_value, traceback) + current_depth = self.atomic_depth_var(self.using) + current_depth.set(current_depth.get() - 1) + connection = await aget_connection(self.using) + + async with self.get_atxn_lock(connection): + if connection.in_atomic_block: + connection.atomic_blocks.pop() + + if connection.savepoint_ids: + sid = connection.savepoint_ids.pop() + else: + # Prematurely unset this flag to allow using commit or rollback. + connection.in_atomic_block = False + + try: + if connection.closed_in_transaction: + # The database will perform a rollback by itself. + # Wait until we exit the outermost block. + pass + + elif exc_type is None and not connection.needs_rollback: + if connection.in_atomic_block: + # Release savepoint if there is one + if sid is not None: + try: + await connection.asavepoint_commit(sid) + except DatabaseError: + try: + await connection.asavepoint_rollback(sid) + # The savepoint won't be reused. Release it to + # minimize overhead for the database server. + await connection.asavepoint_commit(sid) + except Error: + # If rolling back to a savepoint fails, mark for + # rollback at a higher level and avoid shadowing + # the original exception. + connection.needs_rollback = True + raise + else: + # Commit transaction + try: + await connection.acommit() + except DatabaseError: + try: + await connection.arollback() + except Error: + # An error during rollback means that something + # went wrong with the connection. Drop it. + await connection.aclose() + raise + else: + # This flag will be set to True again if there isn't a savepoint + # allowing to perform the rollback at this level. + connection.needs_rollback = False + if connection.in_atomic_block: + # Roll back to savepoint if there is one, mark for rollback + # otherwise. + if sid is None: + connection.needs_rollback = True + else: + try: + await connection.asavepoint_rollback(sid) + # The savepoint won't be reused. Release it to + # minimize overhead for the database server. + await connection.asavepoint_commit(sid) + except Error: + # If rolling back to a savepoint fails, mark for + # rollback at a higher level and avoid shadowing + # the original exception. + connection.needs_rollback = True + else: + # Roll back transaction + try: + await connection.arollback() + except Error: + # An error during rollback means that something + # went wrong with the connection. Drop it. + await connection.aclose() + + finally: + # Outermost block exit when autocommit was enabled. + if not connection.in_atomic_block: + if connection.closed_in_transaction: + connection.connection = None + else: + connection.set_autocommit(True) + # Outermost block exit when autocommit was disabled. + elif not connection.savepoint_ids and not connection.commit_on_exit: + if connection.closed_in_transaction: + connection.connection = None + else: + connection.in_atomic_block = False + def atomic(using=None, savepoint=True, durable=False): # Bare decorator: @atomic -- although the first argument is called diff --git a/django/db/utils.py b/django/db/utils.py index e45f1db249ca..4f0f5e032b65 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -1,6 +1,9 @@ +import os import pkgutil from importlib import import_module +from asgiref.local import Local + from django.conf import settings from django.core.exceptions import ImproperlyConfigured @@ -144,6 +147,10 @@ class ConnectionHandler(BaseConnectionHandler): # after async contexts, though, so we don't allow that if we can help it. thread_critical = True + # a reference to an async connection handler, to be used for building + # proper proxying + async_connections: "AsyncConnectionHandler" + def configure_settings(self, databases): databases = super().configure_settings(databases) if databases == {}: @@ -194,6 +201,105 @@ def create_connection(self, alias): return backend.DatabaseWrapper(db, alias) +class AsyncAlias: + """ + A Context-aware list of connections. + """ + + def __init__(self) -> None: + self._connections = Local() + setattr(self._connections, "_stack", []) + + @property + def connections(self): + return getattr(self._connections, "_stack", []) + + def __len__(self): + return len(self.connections) + + def __iter__(self): + return iter(self.connections) + + def __str__(self): + return ", ".join([str(id(conn)) for conn in self.connections]) + + def __repr__(self): + return f"<{self.__class__.__name__}: {self}>" + + def add_connection(self, connection): + setattr(self._connections, "_stack", self.connections + [connection]) + + def pop(self): + conns = self.connections + conns.pop() + setattr(self._connections, "_stack", conns) + + +class AsyncConnectionHandler: + """ + Context-aware class to store async connections, mapped by alias name. + """ + + LOG_HITS = False + + _from_testcase = False + + # a reference to a sync connection handler, to be used for building + # proper proxying + sync_connections: ConnectionHandler + + def __init__(self) -> None: + self._aliases = Local() + self._connection_count = Local() + setattr(self._connection_count, "value", 0) + + def __getitem__(self, alias): + if self.LOG_HITS: + print(f"ACH.__getitem__[{alias}]") + try: + async_alias = getattr(self._aliases, alias) + except AttributeError: + if self.LOG_HITS: + print("CACHE MISS") + async_alias = AsyncAlias() + setattr(self._aliases, alias, async_alias) + else: + if self.LOG_HITS: + print("CACHE HIT") + return async_alias + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self}>" + + @property + def count(self): + return getattr(self._connection_count, "value", 0) + + @property + def empty(self): + return self.count == 0 + + def add_connection(self, using, connection): + if "QL" in os.environ: + print(f"add_connection {using=}") + self[using].add_connection(connection) + setattr(self._connection_count, "value", self.count + 1) + + def pop_connection(self, using): + if "QL" in os.environ: + print(f"pop_connection {using=}") + self[using].connections.pop() + setattr(self._connection_count, "value", self.count - 1) + + def get_connection(self, using): + alias = self[using] + if len(alias.connections) == 0: + raise ConnectionDoesNotExist( + f"There are no connections using the '{using}' alias." + ) + return alias.connections[-1] + + class ConnectionRouter: def __init__(self, routers=None): """ diff --git a/django/test/runner.py b/django/test/runner.py index c8bb16e7b377..7a95421a6f3e 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -656,12 +656,53 @@ def shuffle(self, items, key): return [hashes[hashed] for hashed in sorted(hashes)] +class SuccessTrackingTextTestResult(unittest.TextTestResult): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.successes = [] + + def addSuccess(self, test): + super().addSuccess(test) + self.successes.append(test) + + +class SuccessTrackingTextTestRunner(unittest.TextTestRunner): + resultclass = SuccessTrackingTextTestResult + + +class PDBDebugResult(SuccessTrackingTextTestResult): + """ + Custom result class that triggers a PDB session when an error or failure + occurs. + """ + + def addError(self, test, err): + super().addError(test, err) + self.debug(err) + + def addFailure(self, test, err): + super().addFailure(test, err) + self.debug(err) + + def addSubTest(self, test, subtest, err): + if err is not None: + self.debug(err) + super().addSubTest(test, subtest, err) + + def debug(self, error): + self._restoreStdout() + self.buffer = False + exc_type, exc_value, traceback = error + print("\nOpening PDB: %r" % exc_value) + pdb.post_mortem(traceback) + + class DiscoverRunner: - """A Django test runner that uses unittest2 test discovery.""" + """A Django tese runner that uses unittest2 test discovery.""" test_suite = unittest.TestSuite parallel_test_suite = ParallelTestSuite - test_runner = unittest.TextTestRunner + test_runner = SuccessTrackingTextTestRunner test_loader = unittest.defaultTestLoader reorder_by = (TestCase, SimpleTestCase) @@ -952,6 +993,20 @@ def build_suite(self, test_labels=None, **kwargs): # _FailedTest objects include things like test modules that couldn't be # found or that couldn't be loaded due to syntax errors. test_types = (unittest.loader._FailedTest, *self.reorder_by) + try: + if os.environ.get("STEPWISE"): + with open("passed.tests", "r") as passed_tests_f: + passed_tests = { + l.strip() for l in passed_tests_f.read().splitlines() + } + else: + passed_tests = set() + except FileNotFoundError: + passed_tests = set() + + if len(passed_tests): + print("Filtering out previously passing tests") + all_tests = [t for t in all_tests if t.id() not in passed_tests] all_tests = list( reorder_tests( all_tests, @@ -1066,6 +1121,19 @@ def get_databases(self, suite): ) return databases + def _update_failed_tracking(self, result): + if result.wasSuccessful(): + try: + print("Removing passed tests") + os.remove("passed.tests") + except FileNotFoundError: + pass + else: + passed_ids = [test.id() for test in result.successes] + with open("passed.tests", "a") as f: + f.write("\n".join(passed_ids)) + print("Wrote passed tests") + def run_tests(self, test_labels, **kwargs): """ Run the unit tests for all the test labels in the provided list. @@ -1088,6 +1156,7 @@ def run_tests(self, test_labels, **kwargs): serialized_aliases=suite.serialized_aliases, ) run_failed = False + result = None try: self.run_checks(databases) result = self.run_suite(suite) @@ -1095,6 +1164,8 @@ def run_tests(self, test_labels, **kwargs): run_failed = True raise finally: + if result is not None: + self._update_failed_tracking(result) try: with self.time_keeper.timed("Total database teardown"): self.teardown_databases(old_config) diff --git a/django/test/testcases.py b/django/test/testcases.py index 8f9ba977a382..688c46ee5ebe 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -38,7 +38,13 @@ from django.core.management.sql import emit_post_migrate_signal from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler from django.core.signals import setting_changed -from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction +from django.db import ( + DEFAULT_DB_ALIAS, + async_connections, + connection, + connections, + transaction, +) from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper from django.forms.fields import CharField from django.http import QueryDict @@ -323,6 +329,56 @@ def debug(self): debug_result = _DebugResult() self._setup_and_call(debug_result, debug=True) + def connect_db_then_run(self, test_method): + + import functools + from contextlib import AsyncExitStack + from django.db import new_connection + + @functools.wraps(test_method) + async def cdb_then_run(*args, **kwargs): + async with AsyncExitStack() as stack: + # connect to all the DBs + for db in self.databases: + aconn = await stack.enter_async_context(new_connection(using=db)) + # import gc + + # refs = gc.get_referents(aconn) + # print(refs) + # import pdb + + # pdb.set_trace() + await test_method(*args, **kwargs) + + return cdb_then_run + + @classmethod + def use_async_connections(cls, test_method): + # set up async connections that will get rollbacked at the + # end of the session + import functools + from contextlib import AsyncExitStack + from django.db import new_connection + + @functools.wraps(test_method) + async def cdb_then_run(self, *args, **kwargs): + async with AsyncExitStack() as stack: + # connect to all the DBs + for db in self.databases: + await stack.enter_async_context( + new_connection(using=db, force_rollback=True) + ) + # import gc + + # refs = gc.get_referents(aconn) + # print(refs) + # import pdb + + # pdb.set_trace() + await test_method(self, *args, **kwargs) + + return cdb_then_run + def _setup_and_call(self, result, debug=False): """ Perform the following in order: pre-setup, run test, post-teardown, @@ -336,9 +392,16 @@ def _setup_and_call(self, result, debug=False): testMethod, "__unittest_skip__", False ) + async_connections._from_testcase = True + # Convert async test methods. if iscoroutinefunction(testMethod): - setattr(self, self._testMethodName, async_to_sync(testMethod)) + setattr( + self, + self._testMethodName, + async_to_sync(testMethod), + # async_to_sync(self.connect_db_then_run(testMethod)), + ) if not skipped: try: @@ -1111,6 +1174,10 @@ def _pre_setup(cls): * If the class has a 'fixtures' attribute, install those fixtures. """ super()._pre_setup() + if not hasattr(cls, "available_apps"): + raise Exception( + "Please define available_apps in TransactionTestCase and its subclasses." + ) if cls.available_apps is not None: apps.set_available_apps(cls.available_apps) cls._available_apps_calls_balanced += 1 diff --git a/django/test/utils.py b/django/test/utils.py index ddb85127dc94..3c85cfb2de5d 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -366,6 +366,25 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False): verbosity=verbosity, keepdb=keepdb, ) + import objgraph + import pdb + + from django.db.backends.postgresql.base import DatabaseWrapper + import gc + + def the_objs(klass): + return [obj for obj in gc.get_objects() if try_isinstance(obj, klass)] + + def try_isinstance(a, b): + try: + return isinstance(a, b) + except: + return False + + active_dbs = [db for db in the_objs(DatabaseWrapper) if db.aconnection] + if len(active_dbs): + print(active_dbs) + pdb.set_trace() connection.creation.destroy_test_db(old_name, verbosity, keepdb) diff --git a/django/utils/asyncio.py b/django/utils/asyncio.py index 1e79f90c2c1b..d17d64c1bff4 100644 --- a/django/utils/asyncio.py +++ b/django/utils/asyncio.py @@ -2,6 +2,8 @@ from asyncio import get_running_loop from functools import wraps +from asgiref.sync import async_to_sync, sync_to_async + from django.core.exceptions import SynchronousOnlyOperation @@ -37,3 +39,16 @@ def inner(*args, **kwargs): return decorator(func) else: return decorator + + +async def alist(to_consume): + """ + This helper method gets a list out of an async iterable + """ + result = [] + async for elt in to_consume: + result.append(elt) + return result + + +agetattr = sync_to_async(getattr) diff --git a/django/utils/codegen/__init__.py b/django/utils/codegen/__init__.py new file mode 100644 index 000000000000..2cdcdae40ced --- /dev/null +++ b/django/utils/codegen/__init__.py @@ -0,0 +1,27 @@ +def _identity(f): + return f + + +def from_codegen(f): + """ + This indicates that the function was gotten from codegen, and + should not be directly modified + """ + return f + + +def generate_unasynced(async_unsafe=False): + """ + This indicates we should unasync this function/method + + async_unsafe indicates whether to add the async_unsafe decorator + """ + + def wrapper(f): + return f + + return wrapper + + +# this marker gets replaced by False when unasyncifying a function +ASYNC_TRUTH_MARKER = True diff --git a/django/utils/codegen/async_helpers.py b/django/utils/codegen/async_helpers.py new file mode 100644 index 000000000000..b5ef59839799 --- /dev/null +++ b/django/utils/codegen/async_helpers.py @@ -0,0 +1,268 @@ +from collections import namedtuple +import libcst as cst +from libcst import FunctionDef, ClassDef, Name, Decorator +from libcst.helpers import get_full_name_for_node + +import argparse +from ast import literal_eval +from typing import Union + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor + + +DecoratorInfo = namedtuple("DecoratorInfo", ["from_codegen", "unasync", "async_unsafe"]) + + +class UnasyncifyMethod(cst.CSTTransformer): + """ + Make a non-sync version of the method + """ + + def __init__(self): + self.await_depth = 0 + + def visit_Await(self, node): + self.await_depth += 1 + + def leave_Await(self, original_node, updated_node): + self.await_depth -= 1 + # we just remove the actual await + return updated_node.expression + + NAMES_TO_REWRITE = { + "aconnection": "connection", + "ASYNC_TRUTH_MARKER": "False", + "acursor": "cursor", + } + + def leave_Name(self, original_node, updated_node): + # some names will get rewritten because we know + # about them + if updated_node.value in self.NAMES_TO_REWRITE: + return updated_node.with_changes( + value=self.NAMES_TO_REWRITE[updated_node.value] + ) + return updated_node + + def unasynced_function_name(self, func_name: str) -> str | None: + """ + Return the function name for an unasync version of this + function (or None if there is no unasync version) + """ + # XXX bit embarassing but... + if func_name == "all": + return None + if func_name.startswith("a"): + return func_name[1:] + elif func_name.startswith("_a"): + return "_" + func_name[2:] + else: + return None + + def leave_Call(self, original_node, updated_node): + if self.await_depth == 0: + # we only transform calls that are part of + # an await expression + return updated_node + + if isinstance(updated_node.func, cst.Name): + func_name: cst.Name = updated_node.func + unasync_name = self.unasynced_function_name(updated_node.func.value) + if unasync_name is not None: + # let's transform it by removing the a + unasync_func_name = func_name.with_changes(value=unasync_name) + return updated_node.with_changes(func=unasync_func_name) + + elif isinstance(updated_node.func, cst.Attribute): + func_name: cst.Name = updated_node.func.attr + unasync_name = self.unasynced_function_name(updated_node.func.attr.value) + if unasync_name is not None: + # let's transform it by removing the a + return updated_node.with_changes( + func=updated_node.func.with_changes( + attr=func_name.with_changes(value=unasync_name) + ) + ) + return updated_node + + def leave_If(self, original_node, updated_node): + + # checking if the original if was "if ASYNC_TRUTH_MARKER" + # (the updated node would have turned this to if False) + if ( + isinstance(original_node.test, cst.Name) + and original_node.test.value == "ASYNC_TRUTH_MARKER" + ): + if updated_node.orelse is not None: + if isinstance(updated_node.orelse, cst.Else): + # unindent + return cst.FlattenSentinel(updated_node.orelse.body.body) + else: + # we seem to have elif continuations so use that + return updated_node.orelse + else: + # if there's no else branch we just remove the node + return cst.RemovalSentinel.REMOVE + return updated_node + + def leave_CompFor(self, original_node, updated_node): + if updated_node.asynchronous is not None: + return updated_node.with_changes(asynchronous=None) + else: + return updated_node + + def leave_For(self, original_node, updated_node): + if updated_node.asynchronous is not None: + return updated_node.with_changes(asynchronous=None) + else: + return updated_node + + def leave_With(self, original_node, updated_node): + if updated_node.asynchronous is not None: + return updated_node.with_changes(asynchronous=None) + else: + return updated_node + + +class UnasyncifyMethodCommand(VisitorBasedCodemodCommand): + DESCRIPTION = "Transform async methods to sync ones" + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + self.class_stack: list[ClassDef] = [] + + def visit_ClassDef(self, original_node): + self.class_stack.append(original_node) + return True + + def leave_ClassDef(self, original_node, updated_node): + self.class_stack.pop() + return updated_node + + def should_be_unasyncified(self, node: FunctionDef): + method_name = get_full_name_for_node(node.name) + # XXX do other checks here as well? + return ( + node.asynchronous + and method_name.startswith("a") + and method_name == "ainit_connection_state" + ) + + def label_as_codegen(self, node: FunctionDef, async_unsafe: bool) -> FunctionDef: + + from_codegen_marker = Decorator(decorator=Name("from_codegen")) + AddImportsVisitor.add_needed_import( + self.context, "django.utils.codegen", "from_codegen" + ) + + decorators_to_add = [from_codegen_marker] + if async_unsafe: + async_unsafe_marker = Decorator(decorator=Name("async_unsafe")) + AddImportsVisitor.add_needed_import( + self.context, "django.utils.asyncio", "async_unsafe" + ) + decorators_to_add.append(async_unsafe_marker) + # we remove generate_unasynced_codegen + return node.with_changes(decorators=[*decorators_to_add, *node.decorators[1:]]) + + def codegenned_func(self, node: FunctionDef) -> bool: + for decorator in node.decorators: + if ( + isinstance(decorator.decorator, Name) + and decorator.decorator.value == "from_codegen" + ): + return True + return False + + generate_unasync_pattern = m.Call( + func=m.Name(value="generate_unasynced"), + ) + + generated_keyword_pattern = m.Arg( + keyword=m.Name(value="async_unsafe"), + value=m.Name(value="True"), + ) + + def decorator_info(self, node: FunctionDef) -> DecoratorInfo: + from_codegen = False + unasync = False + async_unsafe = False + + # we only consider the top decorator, and will copy everything else + if node.decorators: + decorator = node.decorators[0] + if isinstance(decorator.decorator, cst.Name): + if decorator.decorator.value == "from_codegen": + from_codegen = True + elif m.matches(decorator.decorator, self.generate_unasync_pattern): + unasync = True + args = decorator.decorator.args + if len(args) == 0: + async_unsafe = False + elif len(args) == 1: + # assert that it's async_unsafe, our only supported + # keyword for now + assert m.matches( + args[0], self.generated_keyword_pattern + ), f"We only support async_unsafe=True as a keyword argument, got {args}" + async_unsafe = True + else: + raise ValueError( + "generate_unasynced only supports 0 or 1 arguments" + ) + return DecoratorInfo(from_codegen, unasync, async_unsafe) + + def decorator_names(self, node: FunctionDef) -> list[str]: + # get the names of the decorators on this function + # this doesn't try very hard + return [ + decorator.decorator.value + for decorator in node.decorators + if isinstance(decorator.decorator, Name) + ] + + def calculate_new_name(self, old_name): + if old_name.startswith("test_async_"): + # test_async_foo -> test_foo + return old_name.replace("test_async_", "test_", 1) + if old_name.startswith("_a"): + # _ainsert -> _insert + return old_name.replace("_a", "_", 1) + if old_name.startswith("a"): + # aget -> get + return old_name[1:] + raise ValueError( + f""" + Unknown name replacement pasttern for {old_name} + """ + ) + + def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef): + decorator_info = self.decorator_info(updated_node) + # if we are looking at something that's already codegen, drop it + # (it will get regenerated) + if decorator_info.from_codegen: + return cst.RemovalSentinel.REMOVE + + if decorator_info.unasync: + new_name = self.calculate_new_name( + get_full_name_for_node(updated_node.name) + ) + + unasynced_func = updated_node.with_changes( + name=Name(new_name), + asynchronous=None, + ) + unasynced_func = self.label_as_codegen( + unasynced_func, async_unsafe=decorator_info.async_unsafe + ) + unasynced_func = unasynced_func.visit(UnasyncifyMethod()) + + # while here the async version is the canonical version, we place + # the unasync version up on top + return cst.FlattenSentinel([unasynced_func, updated_node]) + else: + return updated_node diff --git a/django/utils/connection.py b/django/utils/connection.py index a278598f251e..609450181a89 100644 --- a/django/utils/connection.py +++ b/django/utils/connection.py @@ -36,6 +36,8 @@ class BaseConnectionHandler: exception_class = ConnectionDoesNotExist thread_critical = False + LOG_HITS = False + def __init__(self, settings=None): self._settings = settings self._connections = Local(self.thread_critical) @@ -53,16 +55,30 @@ def configure_settings(self, settings): def create_connection(self, alias): raise NotImplementedError("Subclasses must implement create_connection().") - def __getitem__(self, alias): + from django.utils.asyncio import async_unsafe + + def get_item(self, alias, raise_on_miss=False): + if self.LOG_HITS: + print(f"CH.__getitem__[{alias}]") try: - return getattr(self._connections, alias) + result = getattr(self._connections, alias) + if self.LOG_HITS: + print("CACHE HIT") + return result except AttributeError: + if raise_on_miss: + raise + if self.LOG_HITS: + print("CACHE MISS") if alias not in self.settings: raise self.exception_class(f"The connection '{alias}' doesn't exist.") conn = self.create_connection(alias) setattr(self._connections, alias, conn) return conn + def __getitem__(self, alias): + return self.get_item(alias) + def __setitem__(self, key, value): setattr(self._connections, key, value) diff --git a/do_tests.nu b/do_tests.nu new file mode 100755 index 000000000000..ebae8f517216 --- /dev/null +++ b/do_tests.nu @@ -0,0 +1,14 @@ +#!/usr/bin/env nu +def main [--codegen] { + if $codegen { + print "Codegenning..." + ./scripts/run_codegen.sh + } + + print "Running with test_postgresql_async" + ./tests/runtests.py async --settings test_postgresql_async --parallel=1 --debug-sql + print "Running with test_sqlite" + ./tests/runtests.py async --settings test_sqlite + print "Running with test_postgresql" + ./tests/runtests.py async --settings test_postgresql +} diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index 15dad66b5443..a113c977b930 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -219,6 +219,14 @@ Database backends * MySQL connections now default to using the ``utf8mb4`` character set, instead of ``utf8``, which is an alias for the deprecated character set ``utf8mb3``. +* It is now possible to perform asynchronous raw SQL queries using an async cursor. + This is only possible on backends that support async-native connections. + Currently only supported in PostreSQL with the ``django.db.backends.postgresql`` + backend. +* It is now possible to perform asynchronous raw SQL queries using an async + cursor, if the backend supports async-native connections. This is only + supported on PostgreSQL with ``psycopg`` 3.1.8+. See + :ref:`async-connection-cursor` for more details. * Oracle backends now support :ref:`connection pools `, by setting ``"pool"`` in the :setting:`OPTIONS` part of your database configuration. @@ -433,6 +441,19 @@ MySQL connections now default to using the ``utf8mb4`` character set, instead of ``utf8``, which is an alias for the deprecated character set ``utf8mb3``. ``utf8mb3`` can be specified in the ``OPTIONS`` part of the ``DATABASES`` setting, if needed for legacy databases. +Models +------ + +* Multiple changes have been made to the undocumented `django.db.models.sql.compiler.SQLCompiler.execute_sql`` + method. + + * ``django.db.models.sql.constants.CURSOR`` has been removed as a possible value + for ``SQLCompiler.execute_sql``'s ``result_type`` parameter. Instead, + ``LEAK_CURSOR`` should be used if you want to receive the cursor back. + * ``ROW_COUNT`` has been added as a result type, which returns the number of rows + returned by the query directly, closing the cursor in the process. + * ``UpdateSQLCompiler.execute_sql`` now only accepts ``NO_RESULT`` and ``LEAK_CURSOR`` + as result types. Miscellaneous ------------- diff --git a/docs/topics/db/sql.txt b/docs/topics/db/sql.txt index 42143fd1189a..93406b510b3b 100644 --- a/docs/topics/db/sql.txt +++ b/docs/topics/db/sql.txt @@ -403,6 +403,33 @@ is equivalent to:: finally: c.close() +.. _async-connection-cursor: + +Async Connections and cursors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 5.2 + +On backends that support async-native connections, you can request an async +cursor:: + + from django.db import new_connection + + async with new_connection() as connection: + async with connection.acursor() as c: + await c.aexecute(...) + +Async cursors provide the following methods: + +* ``.aexecute()`` +* ``.aexecutemany()`` +* ``.afetchone()`` +* ``.afetchmany()`` +* ``.afetchall()`` +* ``.acopy()`` +* ``.astream()`` +* ``.ascroll()`` + Calling stored procedures ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/notes.txt b/notes.txt new file mode 100644 index 000000000000..ca70903df2f4 --- /dev/null +++ b/notes.txt @@ -0,0 +1,23 @@ + +Running: + +tests/runtests.py --settings=test_postgresql generic_relations.tests --noinput + +^ spits out an "AsyncCursor.close was never aawaited" thing + +---- + +I need to write out async with new_connection blocks in tests (maybe my codemod can look +at an environment variable?) + + +---- + +assertNumQueries support in an async context.... + + +---- + +skipUnlessDBFeature etc.... none of this stuff works with async tests. You can tell because test blocks with those labels are never awaited + +--- diff --git a/scripts/run_codegen.sh b/scripts/run_codegen.sh new file mode 100755 index 000000000000..8748a91ff1c4 --- /dev/null +++ b/scripts/run_codegen.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env sh + +# This script runs libcst codegen +python3 -m libcst.tool codemod async_helpers.UnasyncifyMethodCommand django +python3 -m libcst.tool codemod async_helpers.UnasyncifyMethodCommand tests diff --git a/tests/.coveragerc b/tests/.coveragerc index f1ec004854fd..ca2684f1e206 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -5,6 +5,7 @@ data_file = ${RUNTESTS_DIR-.}/.coverages/.coverage omit = */django/utils/autoreload.py source = django +dynamic_context = test_function [report] ignore_errors = True diff --git a/tests/async/models.py b/tests/async/models.py index a09ff799146d..55f04ffe7d13 100644 --- a/tests/async/models.py +++ b/tests/async/models.py @@ -13,3 +13,12 @@ class SimpleModel(models.Model): class ManyToManyModel(models.Model): simples = models.ManyToManyField("SimpleModel") + + +class ModelWithSyncOverride(models.Model): + field = models.IntegerField() + + def save(self, *args, **kwargs): + # we increment our field right before saving + self.field += 1 + super().save(*args, **kwargs) diff --git a/tests/async/test_async_auth.py b/tests/async/test_async_auth.py index 3d5a6b678d00..5096692a359c 100644 --- a/tests/async/test_async_auth.py +++ b/tests/async/test_async_auth.py @@ -7,14 +7,15 @@ ) from django.contrib.auth.models import AnonymousUser, User from django.http import HttpRequest -from django.test import TestCase, override_settings +from django.test import TransactionTestCase, TestCase, override_settings from django.utils.deprecation import RemovedInDjango61Warning -class AsyncAuthTest(TestCase): - @classmethod - def setUpTestData(cls): - cls.test_user = User.objects.create_user( +class AsyncAuthTest(TransactionTestCase): + available_apps = ["django.contrib.auth"] + + def setUp(self): + self.test_user = User.objects.create_user( "testuser", "test@example.com", "testpw" ) diff --git a/tests/async/test_async_model_methods.py b/tests/async/test_async_model_methods.py index d988d7befcb4..27972c5afd1e 100644 --- a/tests/async/test_async_model_methods.py +++ b/tests/async/test_async_model_methods.py @@ -1,18 +1,42 @@ -from django.test import TestCase +from django.test import TestCase, TransactionTestCase -from .models import SimpleModel +from .models import ModelWithSyncOverride, SimpleModel +from django.db import transaction, new_connection +from asgiref.sync import async_to_sync -class AsyncModelOperationTest(TestCase): - @classmethod - def setUpTestData(cls): - cls.s1 = SimpleModel.objects.create(field=0) +# XXX should there be a way of catching this +# class AsyncSyncCominglingTest(TransactionTestCase): +# available_apps = ["async"] + +# async def change_model_with_async(self, obj): +# obj.field = 10 +# await obj.asave() + +# def test_transaction_async_comingling(self): +# with transaction.atomic(): +# s1 = SimpleModel.objects.create(field=0) +# async_to_sync(self.change_model_with_async)(s1) + + +class AsyncModelOperationTest(TransactionTestCase): + + available_apps = ["async"] + + def setUp(self): + super().setUp() + self.s1 = SimpleModel.objects.create(field=0) + + @TestCase.use_async_connections async def test_asave(self): - self.s1.field = 10 - await self.s1.asave() - refetched = await SimpleModel.objects.aget() - self.assertEqual(refetched.field, 10) + from django.db.backends.utils import block_sync_ops + + with block_sync_ops(): + self.s1.field = 10 + await self.s1.asave() + refetched = await SimpleModel.objects.aget() + self.assertEqual(refetched.field, 10) async def test_adelete(self): await self.s1.adelete() @@ -34,3 +58,32 @@ async def test_arefresh_from_db_from_queryset(self): from_queryset=SimpleModel.objects.filter(field__gt=0) ) self.assertEqual(self.s1.field, 20) + + +class TestAsyncModelOverrides(TransactionTestCase): + available_apps = ["async"] + + def setUp(self): + super().setUp() + self.s1 = ModelWithSyncOverride.objects.create(field=5) + + def test_sync_variant(self): + # when saving a ModelWithSyncOverride, we bump up the value of field + self.s1.field = 6 + self.s1.save() + self.assertEqual(self.s1.field, 7) + + async def test_override_handling_in_cxn_context(self): + # when saving with asave, we're actually going to fallback to save + # (including in a new_connection context) + async with new_connection(force_rollback=True): + self.s1.field = 6 + await self.s1.asave() + self.assertEqual(self.s1.field, 7) + + async def test_override_handling(self): + # when saving with asave, we're actually going to fallback to save + # (including outside a new_connection context) + self.s1.field = 6 + await self.s1.asave() + self.assertEqual(self.s1.field, 7) diff --git a/tests/async/test_async_queryset.py b/tests/async/test_async_queryset.py index 374b4576f98f..d063f290e323 100644 --- a/tests/async/test_async_queryset.py +++ b/tests/async/test_async_queryset.py @@ -4,31 +4,37 @@ from asgiref.sync import async_to_sync, sync_to_async -from django.db import NotSupportedError, connection +from django.db import NotSupportedError, connection, new_connection from django.db.models import Prefetch, Sum -from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.test import ( + TransactionTestCase, + TestCase, + skipIfDBFeature, + skipUnlessDBFeature, +) from .models import RelatedModel, SimpleModel -class AsyncQuerySetTest(TestCase): - @classmethod - def setUpTestData(cls): - cls.s1 = SimpleModel.objects.create( +class AsyncQuerySetTest(TransactionTestCase): + available_apps = ["async"] + + def setUp(self): + self.s1 = SimpleModel.objects.create( field=1, created=datetime(2022, 1, 1, 0, 0, 0), ) - cls.s2 = SimpleModel.objects.create( + self.s2 = SimpleModel.objects.create( field=2, created=datetime(2022, 1, 1, 0, 0, 1), ) - cls.s3 = SimpleModel.objects.create( + self.s3 = SimpleModel.objects.create( field=3, created=datetime(2022, 1, 1, 0, 0, 2), ) - cls.r1 = RelatedModel.objects.create(simple=cls.s1) - cls.r2 = RelatedModel.objects.create(simple=cls.s2) - cls.r3 = RelatedModel.objects.create(simple=cls.s3) + self.r1 = RelatedModel.objects.create(simple=self.s1) + self.r2 = RelatedModel.objects.create(simple=self.s2) + self.r3 = RelatedModel.objects.create(simple=self.s3) @staticmethod def _get_db_feature(connection_, feature_name): @@ -88,6 +94,10 @@ async def test_acount_cached_result(self): async def test_aget(self): instance = await SimpleModel.objects.aget(field=1) self.assertEqual(instance, self.s1) + with self.assertRaises(SimpleModel.MultipleObjectsReturned): + await SimpleModel.objects.aget() + with self.assertRaises(SimpleModel.DoesNotExist): + await SimpleModel.objects.aget(field=98) async def test_acreate(self): await SimpleModel.objects.acreate(field=4) @@ -115,17 +125,23 @@ async def test_aupdate_or_create(self): self.assertIs(created, True) self.assertEqual(instance.field, 6) - @skipUnlessDBFeature("has_bulk_insert") - @async_to_sync + def ensure_feature(self, *args): + if not all(getattr(connection.features, feature, False) for feature in args): + self.skipTest(f"Database doesn't support feature(s): {', '.join(args)}") + + def skip_if_feature(self, *args): + if any(getattr(connection.features, feature, False) for feature in args): + self.skipTest(f"Database supports feature(s): {', '.join(args)}") + async def test_abulk_create(self): + self.ensure_feature("has_bulk_insert") instances = [SimpleModel(field=i) for i in range(10)] qs = await SimpleModel.objects.abulk_create(instances) self.assertEqual(len(qs), 10) - @skipUnlessDBFeature("has_bulk_insert", "supports_update_conflicts") - @skipIfDBFeature("supports_update_conflicts_with_target") - @async_to_sync async def test_update_conflicts_unique_field_unsupported(self): + self.ensure_feature("has_bulk_insert", "support_update_conflicts") + self.skip_if_feature("supports_update_conflicts_with_target") msg = ( "This database backend does not support updating conflicts with specifying " "unique fields that can trigger the upsert." @@ -223,9 +239,8 @@ async def test_adelete(self): qs = [o async for o in SimpleModel.objects.all()] self.assertCountEqual(qs, [self.s1, self.s3]) - @skipUnlessDBFeature("supports_explaining_query_execution") - @async_to_sync async def test_aexplain(self): + self.ensure_feature("supports_explaining_query_execution") supported_formats = await sync_to_async(self._get_db_feature)( connection, "supported_explain_formats" ) @@ -257,3 +272,32 @@ async def test_raw(self): sql = "SELECT id, field FROM async_simplemodel WHERE created=%s" qs = SimpleModel.objects.raw(sql, [self.s1.created]) self.assertEqual([o async for o in qs], [self.s1]) + + +# for all the test methods on AsyncQuerySetTest +# we will add a variant, that first opens a new +# async connection + + +def _tests(): + return [(attr, getattr(AsyncQuerySetTest, attr)) for attr in dir(AsyncQuerySetTest)] + + +def wrap_test(original_test, test_name): + """ + Given an async test, provide an async test that + is generating a new connection + """ + new_test_name = test_name + "_new_cxn" + + async def wrapped_test(self): + async with new_connection(force_rollback=True): + await original_test(self) + + wrapped_test.__name__ = new_test_name + return (new_test_name, wrapped_test) + + +for test_name, test in _tests(): + new_name, new_test = wrap_test(test, test_name) + setattr(AsyncQuerySetTest, new_name, new_test) diff --git a/tests/auth_tests/test_remote_user.py b/tests/auth_tests/test_remote_user.py index 85de931c1a08..4d52eca7ddae 100644 --- a/tests/auth_tests/test_remote_user.py +++ b/tests/auth_tests/test_remote_user.py @@ -10,13 +10,15 @@ AsyncClient, Client, TestCase, + TransactionTestCase, modify_settings, override_settings, ) @override_settings(ROOT_URLCONF="auth_tests.urls") -class RemoteUserTest(TestCase): +class RemoteUserTest(TransactionTestCase): + available_apps = ["auth_tests", "django.contrib.auth", "django.contrib.admin"] middleware = "django.contrib.auth.middleware.RemoteUserMiddleware" backend = "django.contrib.auth.backends.RemoteUserBackend" header = "REMOTE_USER" diff --git a/tests/backends/base/test_base_async.py b/tests/backends/base/test_base_async.py new file mode 100644 index 000000000000..39feadcccb63 --- /dev/null +++ b/tests/backends/base/test_base_async.py @@ -0,0 +1,14 @@ +import unittest + +from django.db import connection, new_connection +from django.test import SimpleTestCase + + +class AsyncDatabaseWrapperTests(SimpleTestCase): + @unittest.skipUnless(connection.supports_async is True, "Async DB test") + async def test_async_cursor(self): + async with new_connection(force_rollback=True) as conn: + async with conn.acursor() as cursor: + await cursor.execute("SELECT 1") + result = (await cursor.fetchone())[0] + self.assertEqual(result, 1) diff --git a/tests/basic/tests.py b/tests/basic/tests.py index f6eabfaed7e8..1774298f647f 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -2,6 +2,7 @@ import threading from datetime import datetime, timedelta from unittest import mock +import unittest from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist from django.db import ( @@ -10,6 +11,7 @@ connection, connections, models, + new_connection, transaction, ) from django.db.models.manager import BaseManager @@ -21,6 +23,7 @@ skipUnlessDBFeature, ) from django.test.utils import CaptureQueriesContext +from django.utils.codegen import from_codegen, generate_unasynced from django.utils.connection import ConnectionDoesNotExist from django.utils.translation import gettext_lazy @@ -375,29 +378,61 @@ def test_extra_method_select_argument_with_dashes(self): ) self.assertEqual(articles[0].undashedvalue, 2) + @from_codegen def test_create_relation_with_gettext_lazy(self): """ gettext_lazy objects work when saving model instances through various methods. Refs #10498. """ - notlazy = "test" - lazy = gettext_lazy(notlazy) - Article.objects.create(headline=lazy, pub_date=datetime.now()) - article = Article.objects.get() - self.assertEqual(article.headline, notlazy) - # test that assign + save works with Promise objects - article.headline = lazy - article.save() - self.assertEqual(article.headline, notlazy) - # test .update() - Article.objects.update(headline=lazy) - article = Article.objects.get() - self.assertEqual(article.headline, notlazy) - # still test bulk_create() - Article.objects.all().delete() - Article.objects.bulk_create([Article(headline=lazy, pub_date=datetime.now())]) - article = Article.objects.get() - self.assertEqual(article.headline, notlazy) + with new_connection(force_rollback=True): + notlazy = "test" + lazy = gettext_lazy(notlazy) + Article.objects.create(headline=lazy, pub_date=datetime.now()) + article = Article.objects.get() + self.assertEqual(article.headline, notlazy) + # test that assign + save works with Promise objects + article.headline = lazy + article.save() + self.assertEqual(article.headline, notlazy) + # test .update() + Article.objects.update(headline=lazy) + article = Article.objects.get() + self.assertEqual(article.headline, notlazy) + # still test bulk_create() + Article.objects.all().delete() + Article.objects.bulk_create( + [Article(headline=lazy, pub_date=datetime.now())] + ) + article = Article.objects.get() + self.assertEqual(article.headline, notlazy) + + @generate_unasynced() + async def test_async_create_relation_with_gettext_lazy(self): + """ + gettext_lazy objects work when saving model instances + through various methods. Refs #10498. + """ + async with new_connection(force_rollback=True): + notlazy = "test" + lazy = gettext_lazy(notlazy) + await Article.objects.acreate(headline=lazy, pub_date=datetime.now()) + article = await Article.objects.aget() + self.assertEqual(article.headline, notlazy) + # test that assign + save works with Promise objects + article.headline = lazy + await article.asave() + self.assertEqual(article.headline, notlazy) + # test .update() + await Article.objects.aupdate(headline=lazy) + article = await Article.objects.aget() + self.assertEqual(article.headline, notlazy) + # still test bulk_create() + await Article.objects.all().adelete() + await Article.objects.abulk_create( + [Article(headline=lazy, pub_date=datetime.now())] + ) + article = await Article.objects.aget() + self.assertEqual(article.headline, notlazy) def test_emptyqs(self): msg = "EmptyQuerySet can't be instantiated" @@ -752,6 +787,8 @@ class ManagerTest(SimpleTestCase): "exists", "contains", "explain", + "_ainsert", + "_aupdate", "_insert", "_update", "raw", diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 7b86a2def54d..8c29c69cb4c9 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -8,6 +8,7 @@ OperationalError, ProgrammingError, connection, + new_connection, ) from django.db.models import FileField, Value from django.db.models.functions import Lower, Now @@ -17,6 +18,7 @@ skipIfDBFeature, skipUnlessDBFeature, ) +from django.utils.codegen import from_codegen, generate_unasynced from .models import ( BigAutoFieldModel, @@ -47,23 +49,45 @@ def setUp(self): Country(name="Czech Republic", iso_two_letter="CZ"), ] + @from_codegen def test_simple(self): - created = Country.objects.bulk_create(self.data) - self.assertEqual(created, self.data) - self.assertQuerySetEqual( - Country.objects.order_by("-name"), - [ - "United States of America", - "The Netherlands", - "Germany", - "Czech Republic", - ], - attrgetter("name"), - ) + with new_connection(force_rollback=True): + created = Country.objects.bulk_create(self.data) + self.assertEqual(created, self.data) - created = Country.objects.bulk_create([]) - self.assertEqual(created, []) - self.assertEqual(Country.objects.count(), 4) + self.assertListEqual( + [c.name for c in Country.objects.order_by("-name")], + [ + "United States of America", + "The Netherlands", + "Germany", + "Czech Republic", + ], + ) + + created = Country.objects.bulk_create([]) + self.assertEqual(created, []) + self.assertEqual(Country.objects.count(), 4) + + @generate_unasynced() + async def test_async_simple(self): + async with new_connection(force_rollback=True): + created = await Country.objects.abulk_create(self.data) + self.assertEqual(created, self.data) + + self.assertListEqual( + [c.name async for c in Country.objects.order_by("-name")], + [ + "United States of America", + "The Netherlands", + "Germany", + "Czech Republic", + ], + ) + + created = await Country.objects.abulk_create([]) + self.assertEqual(created, []) + self.assertEqual(await Country.objects.acount(), 4) @skipUnlessDBFeature("has_bulk_insert") def test_efficiency(self): @@ -92,26 +116,51 @@ def test_long_and_short_text(self): ) self.assertEqual(Country.objects.count(), 4) + @from_codegen def test_multi_table_inheritance_unsupported(self): - expected_message = "Can't bulk create a multi-table inherited model" - with self.assertRaisesMessage(ValueError, expected_message): - Pizzeria.objects.bulk_create( - [ - Pizzeria(name="The Art of Pizza"), - ] - ) - with self.assertRaisesMessage(ValueError, expected_message): - ProxyMultiCountry.objects.bulk_create( - [ - ProxyMultiCountry(name="Fillory", iso_two_letter="FL"), - ] - ) - with self.assertRaisesMessage(ValueError, expected_message): - ProxyMultiProxyCountry.objects.bulk_create( - [ - ProxyMultiProxyCountry(name="Fillory", iso_two_letter="FL"), - ] - ) + with new_connection(force_rollback=True): + expected_message = "Can't bulk create a multi-table inherited model" + with self.assertRaisesMessage(ValueError, expected_message): + Pizzeria.objects.bulk_create( + [ + Pizzeria(name="The Art of Pizza"), + ] + ) + with self.assertRaisesMessage(ValueError, expected_message): + ProxyMultiCountry.objects.bulk_create( + [ + ProxyMultiCountry(name="Fillory", iso_two_letter="FL"), + ] + ) + with self.assertRaisesMessage(ValueError, expected_message): + ProxyMultiProxyCountry.objects.bulk_create( + [ + ProxyMultiProxyCountry(name="Fillory", iso_two_letter="FL"), + ] + ) + + @generate_unasynced() + async def test_async_multi_table_inheritance_unsupported(self): + async with new_connection(force_rollback=True): + expected_message = "Can't bulk create a multi-table inherited model" + with self.assertRaisesMessage(ValueError, expected_message): + await Pizzeria.objects.abulk_create( + [ + Pizzeria(name="The Art of Pizza"), + ] + ) + with self.assertRaisesMessage(ValueError, expected_message): + await ProxyMultiCountry.objects.abulk_create( + [ + ProxyMultiCountry(name="Fillory", iso_two_letter="FL"), + ] + ) + with self.assertRaisesMessage(ValueError, expected_message): + await ProxyMultiProxyCountry.objects.abulk_create( + [ + ProxyMultiProxyCountry(name="Fillory", iso_two_letter="FL"), + ] + ) def test_proxy_inheritance_supported(self): ProxyCountry.objects.bulk_create( @@ -253,20 +302,39 @@ def test_large_batch_mixed_efficiency(self): ) self.assertLess(len(connection.queries), 10) + @from_codegen def test_explicit_batch_size(self): - objs = [TwoFields(f1=i, f2=i) for i in range(0, 4)] - num_objs = len(objs) - TwoFields.objects.bulk_create(objs, batch_size=1) - self.assertEqual(TwoFields.objects.count(), num_objs) - TwoFields.objects.all().delete() - TwoFields.objects.bulk_create(objs, batch_size=2) - self.assertEqual(TwoFields.objects.count(), num_objs) - TwoFields.objects.all().delete() - TwoFields.objects.bulk_create(objs, batch_size=3) - self.assertEqual(TwoFields.objects.count(), num_objs) - TwoFields.objects.all().delete() - TwoFields.objects.bulk_create(objs, batch_size=num_objs) - self.assertEqual(TwoFields.objects.count(), num_objs) + with new_connection(force_rollback=True): + objs = [TwoFields(f1=i, f2=i) for i in range(0, 4)] + num_objs = len(objs) + TwoFields.objects.bulk_create(objs, batch_size=1) + self.assertEqual(TwoFields.objects.count(), num_objs) + TwoFields.objects.all().delete() + TwoFields.objects.bulk_create(objs, batch_size=2) + self.assertEqual(TwoFields.objects.count(), num_objs) + TwoFields.objects.all().delete() + TwoFields.objects.bulk_create(objs, batch_size=3) + self.assertEqual(TwoFields.objects.count(), num_objs) + TwoFields.objects.all().delete() + TwoFields.objects.bulk_create(objs, batch_size=num_objs) + self.assertEqual(TwoFields.objects.count(), num_objs) + + @generate_unasynced() + async def test_async_explicit_batch_size(self): + async with new_connection(force_rollback=True): + objs = [TwoFields(f1=i, f2=i) for i in range(0, 4)] + num_objs = len(objs) + await TwoFields.objects.abulk_create(objs, batch_size=1) + self.assertEqual(await TwoFields.objects.acount(), num_objs) + await TwoFields.objects.all().adelete() + await TwoFields.objects.abulk_create(objs, batch_size=2) + self.assertEqual(await TwoFields.objects.acount(), num_objs) + await TwoFields.objects.all().adelete() + await TwoFields.objects.abulk_create(objs, batch_size=3) + self.assertEqual(await TwoFields.objects.acount(), num_objs) + await TwoFields.objects.all().adelete() + await TwoFields.objects.abulk_create(objs, batch_size=num_objs) + self.assertEqual(await TwoFields.objects.acount(), num_objs) def test_empty_model(self): NoFields.objects.bulk_create([NoFields() for i in range(2)]) @@ -442,6 +510,12 @@ def test_invalid_batch_size_exception(self): with self.assertRaisesMessage(ValueError, msg): Country.objects.bulk_create([], batch_size=-1) + async def test_invalid_batch_size_exception_async(self): + msg = "Batch size must be a positive integer." + async with new_connection(force_rollback=True): + with self.assertRaisesMessage(ValueError, msg): + await Country.objects.abulk_create([], batch_size=-1) + @skipIfDBFeature("supports_update_conflicts") def test_update_conflicts_unsupported(self): msg = "This database backend does not support updating conflicts." diff --git a/tests/composite_pk/test_aggregate.py b/tests/composite_pk/test_aggregate.py index d852fdce30c0..59bc64a01c4f 100644 --- a/tests/composite_pk/test_aggregate.py +++ b/tests/composite_pk/test_aggregate.py @@ -1,5 +1,6 @@ from django.db.models import Count, Max, Q from django.test import TestCase +from django.utils.codegen import from_codegen, generate_unasynced from .models import Comment, Tenant, User @@ -137,7 +138,14 @@ def test_order_by_comments_id_count(self): (self.user_3, self.user_1, self.user_2), ) + @from_codegen def test_max_pk(self): msg = "Max expression does not support composite primary keys." with self.assertRaisesMessage(ValueError, msg): Comment.objects.aggregate(Max("pk")) + + @generate_unasynced() + async def test_async_max_pk(self): + msg = "Max expression does not support composite primary keys." + with self.assertRaisesMessage(ValueError, msg): + await Comment.objects.aaggregate(Max("pk")) diff --git a/tests/db_utils/tests.py b/tests/db_utils/tests.py index 4028a8acdf3e..431240f3293f 100644 --- a/tests/db_utils/tests.py +++ b/tests/db_utils/tests.py @@ -1,10 +1,25 @@ """Tests for django.db.utils.""" +import asyncio +import concurrent.futures import unittest +from unittest import mock from django.core.exceptions import ImproperlyConfigured -from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection -from django.db.utils import ConnectionHandler, load_backend +from django.db import ( + DEFAULT_DB_ALIAS, + NotSupportedError, + ProgrammingError, + async_connections, + connection, + new_connection, +) +from django.db.utils import ( + AsyncAlias, + AsyncConnectionHandler, + ConnectionHandler, + load_backend, +) from django.test import SimpleTestCase, TestCase from django.utils.connection import ConnectionDoesNotExist @@ -90,3 +105,81 @@ def test_load_backend_invalid_name(self): with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm: load_backend("foo") self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'") + + +class AsyncConnectionTests(SimpleTestCase): + def run_pool(self, coro, count=2): + def fn(): + asyncio.run(coro()) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + futures = [] + for _ in range(count): + futures.append(executor.submit(fn)) + + for future in concurrent.futures.as_completed(futures): + exc = future.exception() + if exc is not None: + raise exc + + def test_async_alias(self): + alias = AsyncAlias() + assert len(alias) == 0 + assert alias.connections == [] + + async def coro(): + assert len(alias) == 0 + alias.add_connection(mock.Mock()) + alias.pop() + + self.run_pool(coro) + + def test_async_connection_handler(self): + aconns = AsyncConnectionHandler() + assert aconns.empty is True + assert aconns["default"].connections == [] + + async def coro(): + assert aconns["default"].connections == [] + aconns.add_connection("default", mock.Mock()) + aconns.pop_connection("default") + + self.run_pool(coro) + + @unittest.skipUnless(connection.supports_async is True, "Async DB test") + def test_new_connection_threading(self): + async def coro(): + assert async_connections.empty is True + async with new_connection(force_rollback=True) as connection: + async with connection.acursor() as c: + await c.execute("SELECT 1") + + self.run_pool(coro) + + @unittest.skipUnless(connection.supports_async is True, "Async DB test") + async def test_new_connection(self): + with self.assertRaises(ConnectionDoesNotExist): + async_connections.get_connection(DEFAULT_DB_ALIAS) + + async with new_connection(force_rollback=True) as aconn: + conn1 = async_connections.get_connection(DEFAULT_DB_ALIAS) + self.assertEqual(conn1, aconn) + self.assertIsNotNone(conn1.aconnection) + async with new_connection(force_rollback=True): + conn2 = async_connections.get_connection(DEFAULT_DB_ALIAS) + self.assertIsNotNone(conn1.aconnection) + self.assertIsNotNone(conn2.aconnection) + self.assertNotEqual(conn1.aconnection, conn2.aconnection) + + self.assertIsNotNone(conn1.aconnection) + self.assertIsNone(conn2.aconnection) + self.assertIsNone(conn1.aconnection) + + with self.assertRaises(ConnectionDoesNotExist): + async_connections.get_connection(DEFAULT_DB_ALIAS) + + @unittest.skipUnless(connection.supports_async is False, "Sync DB test") + async def test_new_connection_on_sync(self): + with self.assertRaises(NotSupportedError): + async with new_connection(force_rollback=True): + async_connections.get_connection(DEFAULT_DB_ALIAS) diff --git a/tests/defer/tests.py b/tests/defer/tests.py index 989b5c63d788..6b76ab612d72 100644 --- a/tests/defer/tests.py +++ b/tests/defer/tests.py @@ -1,5 +1,10 @@ +from unittest import expectedFailure +from unittest.case import skip from django.core.exceptions import FieldDoesNotExist, FieldError +from django.db import new_connection from django.test import SimpleTestCase, TestCase +from django.utils.asyncio import alist, agetattr +from django.utils.codegen import from_codegen, generate_unasynced from .models import ( BigChild, @@ -231,6 +236,8 @@ def test_only_subclass(self): class TestDefer2(AssertionMixin, TestCase): + + @from_codegen def test_defer_proxy(self): """ Ensure select_related together with only on a proxy model behaves @@ -238,13 +245,34 @@ def test_defer_proxy(self): """ related = Secondary.objects.create(first="x1", second="x2") ChildProxy.objects.create(name="p1", value="xx", related=related) - children = ChildProxy.objects.select_related().only("id", "name") + children = list(ChildProxy.objects.select_related().only("id", "name")) self.assertEqual(len(children), 1) child = children[0] self.assert_delayed(child, 2) self.assertEqual(child.name, "p1") self.assertEqual(child.value, "xx") + # maybe there is actually no answer for attribute access in await contexts + # but that feels very weird to me + @skip("XXX Proxy object stuff is weird") + @generate_unasynced() + async def test_async_defer_proxy(self): + """ + Ensure select_related together with only on a proxy model behaves + as expected. See #17876. + """ + async with new_connection(force_rollback=True): + related = await Secondary.objects.acreate(first="x1", second="x2") + await ChildProxy.objects.acreate(name="p1", value="xx", related=related) + children = await alist( + ChildProxy.objects.select_related().only("id", "name") + ) + self.assertEqual(len(children), 1) + child = children[0] + self.assert_delayed(child, 2) + self.assertEqual(await agetattr(child, "name"), "p1") + self.assertEqual(await agetattr(child, "value"), "xx") + def test_defer_inheritance_pk_chaining(self): """ When an inherited model is fetched from the DB, its PK is also fetched. diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index e0c6fe2db756..62abdc9a312e 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -1,8 +1,12 @@ from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.prefetch import GenericPrefetch from django.core.exceptions import FieldError +from django.db import new_connection from django.db.models import Q, prefetch_related_objects from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature +from django.test.testcases import TransactionTestCase +from django.utils.asyncio import alist +from django.utils.codegen import from_codegen, generate_unasynced from .models import ( AllowsNullGFK, @@ -575,21 +579,41 @@ def test_get_or_create(self): self.assertEqual(tag.tag, "shiny") self.assertEqual(tag.content_object.id, quartz.id) + @from_codegen def test_update_or_create_defaults(self): - # update_or_create should work with virtual fields (content_object) - quartz = Mineral.objects.create(name="Quartz", hardness=7) - diamond = Mineral.objects.create(name="Diamond", hardness=7) - tag, created = TaggedItem.objects.update_or_create( - tag="shiny", defaults={"content_object": quartz} - ) - self.assertTrue(created) - self.assertEqual(tag.content_object.id, quartz.id) + with new_connection(force_rollback=True): + # update_or_create should work with virtual fields (content_object) + quartz = Mineral.objects.create(name="Quartz", hardness=7) + diamond = Mineral.objects.create(name="Diamond", hardness=7) + tag, created = TaggedItem.objects.update_or_create( + tag="shiny", defaults={"content_object": quartz} + ) + self.assertTrue(created) + self.assertEqual(tag.content_object.id, quartz.id) - tag, created = TaggedItem.objects.update_or_create( - tag="shiny", defaults={"content_object": diamond} - ) - self.assertFalse(created) - self.assertEqual(tag.content_object.id, diamond.id) + tag, created = TaggedItem.objects.update_or_create( + tag="shiny", defaults={"content_object": diamond} + ) + self.assertFalse(created) + self.assertEqual(tag.content_object.id, diamond.id) + + @generate_unasynced() + async def test_async_update_or_create_defaults(self): + async with new_connection(force_rollback=True): + # update_or_create should work with virtual fields (content_object) + quartz = await Mineral.objects.acreate(name="Quartz", hardness=7) + diamond = await Mineral.objects.acreate(name="Diamond", hardness=7) + tag, created = await TaggedItem.objects.aupdate_or_create( + tag="shiny", defaults={"content_object": quartz} + ) + self.assertTrue(created) + self.assertEqual(tag.content_object.id, quartz.id) + + tag, created = await TaggedItem.objects.aupdate_or_create( + tag="shiny", defaults={"content_object": diamond} + ) + self.assertFalse(created) + self.assertEqual(tag.content_object.id, diamond.id) def test_update_or_create_defaults_with_create_defaults(self): # update_or_create() should work with virtual fields (content_object). @@ -860,3 +884,46 @@ def test_none_allowed(self): # TaggedItem requires a content_type but initializing with None should # be allowed. TaggedItem(content_object=None) + + +class GenericRelationsAsyncTest(TransactionTestCase): + """ + XXX These tests are split out so that we can run the tests without setUpTestData, + as those tests are running within a single transaction + """ + + available_apps = ["generic_relations"] + + def setUp(self): + self.platypus = Animal.objects.create( + common_name="Platypus", + latin_name="Ornithorhynchus anatinus", + ) + + @from_codegen + def test_add_then_remove_after_prefetch(self): + with new_connection(force_rollback=True): + furry_tag = self.platypus.tags.create(tag="furry") + platypus = Animal.objects.prefetch_related("tags").get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + weird_tag = self.platypus.tags.create(tag="weird") + platypus.tags.add(weird_tag) + self.assertSequenceEqual(list(platypus.tags.all()), [furry_tag, weird_tag]) + platypus.tags.remove(weird_tag) + self.assertSequenceEqual(list(platypus.tags.all()), [furry_tag]) + + @generate_unasynced() + async def test_async_add_then_remove_after_prefetch(self): + async with new_connection(force_rollback=True): + furry_tag = await self.platypus.tags.acreate(tag="furry") + platypus = await Animal.objects.prefetch_related("tags").aget( + pk=self.platypus.pk + ) + self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + weird_tag = await self.platypus.tags.acreate(tag="weird") + await platypus.tags.aadd(weird_tag) + self.assertSequenceEqual( + await alist(platypus.tags.all()), [furry_tag, weird_tag] + ) + await platypus.tags.aremove(weird_tag) + self.assertSequenceEqual(await alist(platypus.tags.all()), [furry_tag]) diff --git a/tests/get_earliest_or_latest/tests.py b/tests/get_earliest_or_latest/tests.py index 21692590ccfd..d25a934b8c66 100644 --- a/tests/get_earliest_or_latest/tests.py +++ b/tests/get_earliest_or_latest/tests.py @@ -1,7 +1,9 @@ from datetime import datetime +from django.db import new_connection from django.db.models import Avg from django.test import TestCase +from django.utils.codegen import from_codegen, generate_unasynced from .models import Article, Comment, IndexErrorArticle, Person @@ -17,83 +19,173 @@ def setUpClass(cls): def tearDown(self): Article._meta.get_latest_by = self._article_get_latest_by + @from_codegen def test_earliest(self): - # Because no Articles exist yet, earliest() raises ArticleDoesNotExist. - with self.assertRaises(Article.DoesNotExist): - Article.objects.earliest() - - a1 = Article.objects.create( - headline="Article 1", - pub_date=datetime(2005, 7, 26), - expire_date=datetime(2005, 9, 1), - ) - a2 = Article.objects.create( - headline="Article 2", - pub_date=datetime(2005, 7, 27), - expire_date=datetime(2005, 7, 28), - ) - a3 = Article.objects.create( - headline="Article 3", - pub_date=datetime(2005, 7, 28), - expire_date=datetime(2005, 8, 27), - ) - a4 = Article.objects.create( - headline="Article 4", - pub_date=datetime(2005, 7, 28), - expire_date=datetime(2005, 7, 30), - ) - - # Get the earliest Article. - self.assertEqual(Article.objects.earliest(), a1) - # Get the earliest Article that matches certain filters. - self.assertEqual( - Article.objects.filter(pub_date__gt=datetime(2005, 7, 26)).earliest(), a2 - ) - - # Pass a custom field name to earliest() to change the field that's used - # to determine the earliest object. - self.assertEqual(Article.objects.earliest("expire_date"), a2) - self.assertEqual( - Article.objects.filter(pub_date__gt=datetime(2005, 7, 26)).earliest( - "expire_date" - ), - a2, - ) - - # earliest() overrides any other ordering specified on the query. - # Refs #11283. - self.assertEqual(Article.objects.order_by("id").earliest(), a1) - - # Error is raised if the user forgot to add a get_latest_by - # in the Model.Meta - Article.objects.model._meta.get_latest_by = None - with self.assertRaisesMessage( - ValueError, - "earliest() and latest() require either fields as positional " - "arguments or 'get_latest_by' in the model's Meta.", - ): - Article.objects.earliest() - - # Earliest publication date, earliest expire date. - self.assertEqual( - Article.objects.filter(pub_date=datetime(2005, 7, 28)).earliest( - "pub_date", "expire_date" - ), - a4, - ) - # Earliest publication date, latest expire date. - self.assertEqual( - Article.objects.filter(pub_date=datetime(2005, 7, 28)).earliest( - "pub_date", "-expire_date" - ), - a3, - ) - - # Meta.get_latest_by may be a tuple. - Article.objects.model._meta.get_latest_by = ("pub_date", "expire_date") - self.assertEqual( - Article.objects.filter(pub_date=datetime(2005, 7, 28)).earliest(), a4 - ) + with new_connection(force_rollback=True): + # Because no Articles exist yet, earliest() raises ArticleDoesNotExist. + with self.assertRaises(Article.DoesNotExist): + Article.objects.earliest() + + a1 = Article.objects.create( + headline="Article 1", + pub_date=datetime(2005, 7, 26), + expire_date=datetime(2005, 9, 1), + ) + a2 = Article.objects.create( + headline="Article 2", + pub_date=datetime(2005, 7, 27), + expire_date=datetime(2005, 7, 28), + ) + a3 = Article.objects.create( + headline="Article 3", + pub_date=datetime(2005, 7, 28), + expire_date=datetime(2005, 8, 27), + ) + a4 = Article.objects.create( + headline="Article 4", + pub_date=datetime(2005, 7, 28), + expire_date=datetime(2005, 7, 30), + ) + + # Get the earliest Article. + self.assertEqual(Article.objects.earliest(), a1) + # Get the earliest Article that matches certain filters. + self.assertEqual( + Article.objects.filter(pub_date__gt=datetime(2005, 7, 26)).earliest(), + a2, + ) + + # Pass a custom field name to earliest() to change the field that's used + # to determine the earliest object. + self.assertEqual(Article.objects.earliest("expire_date"), a2) + self.assertEqual( + Article.objects.filter(pub_date__gt=datetime(2005, 7, 26)).earliest( + "expire_date" + ), + a2, + ) + + # earliest() overrides any other ordering specified on the query. + # Refs #11283. + self.assertEqual(Article.objects.order_by("id").earliest(), a1) + + # Error is raised if the user forgot to add a get_latest_by + # in the Model.Meta + Article.objects.model._meta.get_latest_by = None + with self.assertRaisesMessage( + ValueError, + "earliest() and latest() require either fields as positional " + "arguments or 'get_latest_by' in the model's Meta.", + ): + Article.objects.earliest() + + # Earliest publication date, earliest expire date. + self.assertEqual( + Article.objects.filter(pub_date=datetime(2005, 7, 28)).earliest( + "pub_date", "expire_date" + ), + a4, + ) + # Earliest publication date, latest expire date. + self.assertEqual( + Article.objects.filter(pub_date=datetime(2005, 7, 28)).earliest( + "pub_date", "-expire_date" + ), + a3, + ) + + # Meta.get_latest_by may be a tuple. + Article.objects.model._meta.get_latest_by = ("pub_date", "expire_date") + self.assertEqual( + Article.objects.filter(pub_date=datetime(2005, 7, 28)).earliest(), + a4, + ) + + @generate_unasynced() + async def test_async_earliest(self): + async with new_connection(force_rollback=True): + # Because no Articles exist yet, earliest() raises ArticleDoesNotExist. + with self.assertRaises(Article.DoesNotExist): + await Article.objects.aearliest() + + a1 = await Article.objects.acreate( + headline="Article 1", + pub_date=datetime(2005, 7, 26), + expire_date=datetime(2005, 9, 1), + ) + a2 = await Article.objects.acreate( + headline="Article 2", + pub_date=datetime(2005, 7, 27), + expire_date=datetime(2005, 7, 28), + ) + a3 = await Article.objects.acreate( + headline="Article 3", + pub_date=datetime(2005, 7, 28), + expire_date=datetime(2005, 8, 27), + ) + a4 = await Article.objects.acreate( + headline="Article 4", + pub_date=datetime(2005, 7, 28), + expire_date=datetime(2005, 7, 30), + ) + + # Get the earliest Article. + self.assertEqual(await Article.objects.aearliest(), a1) + # Get the earliest Article that matches certain filters. + self.assertEqual( + await Article.objects.filter( + pub_date__gt=datetime(2005, 7, 26) + ).aearliest(), + a2, + ) + + # Pass a custom field name to earliest() to change the field that's used + # to determine the earliest object. + self.assertEqual(await Article.objects.aearliest("expire_date"), a2) + self.assertEqual( + await Article.objects.filter( + pub_date__gt=datetime(2005, 7, 26) + ).aearliest("expire_date"), + a2, + ) + + # earliest() overrides any other ordering specified on the query. + # Refs #11283. + self.assertEqual(await Article.objects.order_by("id").aearliest(), a1) + + # Error is raised if the user forgot to add a get_latest_by + # in the Model.Meta + Article.objects.model._meta.get_latest_by = None + with self.assertRaisesMessage( + ValueError, + "earliest() and latest() require either fields as positional " + "arguments or 'get_latest_by' in the model's Meta.", + ): + await Article.objects.aearliest() + + # Earliest publication date, earliest expire date. + self.assertEqual( + await Article.objects.filter(pub_date=datetime(2005, 7, 28)).aearliest( + "pub_date", "expire_date" + ), + a4, + ) + # Earliest publication date, latest expire date. + self.assertEqual( + await Article.objects.filter(pub_date=datetime(2005, 7, 28)).aearliest( + "pub_date", "-expire_date" + ), + a3, + ) + + # Meta.get_latest_by may be a tuple. + Article.objects.model._meta.get_latest_by = ("pub_date", "expire_date") + self.assertEqual( + await Article.objects.filter( + pub_date=datetime(2005, 7, 28) + ).aearliest(), + a4, + ) def test_earliest_sliced_queryset(self): msg = "Cannot change a query once a slice has been taken." diff --git a/tests/get_or_create/tests.py b/tests/get_or_create/tests.py index 59f84be221fc..9ac8156149d8 100644 --- a/tests/get_or_create/tests.py +++ b/tests/get_or_create/tests.py @@ -5,9 +5,10 @@ from unittest.mock import patch from django.core.exceptions import FieldError -from django.db import DatabaseError, IntegrityError, connection +from django.db import DatabaseError, IntegrityError, connection, new_connection from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext +from django.utils.codegen import from_codegen, generate_unasynced from django.utils.functional import lazy from .models import ( @@ -68,6 +69,7 @@ def test_get_or_create_redundant_instance(self): self.assertFalse(created) self.assertEqual(Person.objects.count(), 2) + @from_codegen def test_get_or_create_invalid_params(self): """ If you don't specify a value or default value for all required @@ -76,6 +78,16 @@ def test_get_or_create_invalid_params(self): with self.assertRaises(IntegrityError): Person.objects.get_or_create(first_name="Tom", last_name="Smith") + @generate_unasynced() + async def test_async_get_or_create_invalid_params(self): + """ + If you don't specify a value or default value for all required + fields, you will get an error. + """ + async with new_connection(force_rollback=True): + with self.assertRaises(IntegrityError): + await Person.objects.aget_or_create(first_name="Tom", last_name="Smith") + def test_get_or_create_with_pk_property(self): """ Using the pk property of a model is allowed. diff --git a/tests/queries/test_db_returning.py b/tests/queries/test_db_returning.py index 50c164a57f94..1f0bc4f33b6e 100644 --- a/tests/queries/test_db_returning.py +++ b/tests/queries/test_db_returning.py @@ -1,8 +1,9 @@ import datetime -from django.db import connection +from django.db import connection, new_connection from django.test import TestCase, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext +from django.utils.codegen import from_codegen, generate_unasynced from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel @@ -45,11 +46,28 @@ def test_insert_returning_multiple(self): self.assertTrue(obj.pk) self.assertIsInstance(obj.created, datetime.datetime) - @skipUnlessDBFeature("can_return_rows_from_bulk_insert") + # XXX need to put this back in, after I figure out how to support this with + # async tests.... + # @skipUnlessDBFeature("can_return_rows_from_bulk_insert") + @from_codegen def test_bulk_insert(self): - objs = [ReturningModel(), ReturningModel(pk=2**11), ReturningModel()] - ReturningModel.objects.bulk_create(objs) - for obj in objs: - with self.subTest(obj=obj): - self.assertTrue(obj.pk) - self.assertIsInstance(obj.created, datetime.datetime) + with new_connection(force_rollback=True): + objs = [ReturningModel(), ReturningModel(pk=2**11), ReturningModel()] + ReturningModel.objects.bulk_create(objs) + for obj in objs: + with self.subTest(obj=obj): + self.assertTrue(obj.pk) + self.assertIsInstance(obj.created, datetime.datetime) + + # XXX need to put this back in, after I figure out how to support this with + # async tests.... + # @skipUnlessDBFeature("can_return_rows_from_bulk_insert") + @generate_unasynced() + async def test_async_bulk_insert(self): + async with new_connection(force_rollback=True): + objs = [ReturningModel(), ReturningModel(pk=2**11), ReturningModel()] + await ReturningModel.objects.abulk_create(objs) + for obj in objs: + with self.subTest(obj=obj): + self.assertTrue(obj.pk) + self.assertIsInstance(obj.created, datetime.datetime) diff --git a/tests/run_async_qs.sh b/tests/run_async_qs.sh new file mode 100755 index 000000000000..4d37b7b60254 --- /dev/null +++ b/tests/run_async_qs.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env sh +set -e +coverage erase +# coverage run ./runtests.py -k AsyncQuerySetTest -k AsyncNativeQuerySetTest -k test_acount --settings=test_postgresql --keepdb --parallel=1 +STEPWISE=1 coverage run ./runtests.py --settings=test_postgresql --noinput || true # --keepdb --parallel=1 +coverage combine +# echo "Generating coverage for db/models/query.py..." +# coverage html --include '**/db/models/query.py' +echo "Generating coverage.." +coverage html --show-contexts # --include '**/db/models/query.py' +open coverage_html/index.html diff --git a/tests/runtests.py b/tests/runtests.py index e9052ca4a947..c63afd1b1e89 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -13,6 +13,8 @@ import warnings from pathlib import Path +print("HI!!!", file=sys.stderr) +print("HI!!!", file=sys.stdout) try: import django except ImportError as e: @@ -315,6 +317,10 @@ def no_available_apps(cls): ) TransactionTestCase.available_apps = classproperty(no_available_apps) + # NOTE[Raphael]: no_available_apps actually doesn't work in certain + # circumstances, but I'm having trouble remember what.... + # del TransactionTestCase.available_apps + # TransactionTestCase.available_apps = property(no_available_apps) TestCase.available_apps = None # Set an environment variable that other code may consult to see if diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py new file mode 100644 index 000000000000..8d063d551358 --- /dev/null +++ b/tests/test_postgresql.py @@ -0,0 +1,53 @@ +import os +from test_sqlite import * # NOQA + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.postgresql", + "USER": "user", + "NAME": "django", + "PASSWORD": "postgres", + "HOST": "localhost", + "PORT": 5432, + "OPTIONS": { + "server_side_binding": os.getenv("SERVER_SIDE_BINDING") == "1", + }, + }, + "other": { + "ENGINE": "django.db.backends.postgresql", + "USER": "user", + "NAME": "django2", + "PASSWORD": "postgres", + "HOST": "localhost", + "PORT": 5432, + }, +} + +from django.db import connection +from django.db.backends.signals import connection_created +from django.dispatch import receiver + + +def set_sync_timeout(connection): + with connection.cursor() as cursor: + cursor.execute("SET statement_timeout to 100000;") + + +async def set_async_timeout(connection): + async with connection.acursor() as cursor: + await cursor.aexecute("SET statement_timeout to 100000;") + + +from asgiref.sync import sync_to_async + + +@receiver(connection_created) +async def set_statement_timeout(sender, connection, **kwargs): + if connection.vendor == "postgresql": + if connection.connection is not None: + await sync_to_async(set_sync_timeout)(connection) + if connection.aconnection is not None: + await set_async_timeout(connection) + + +print("Gotten!") diff --git a/tests/test_postgresql_async.py b/tests/test_postgresql_async.py new file mode 100644 index 000000000000..2c4c3a32729c --- /dev/null +++ b/tests/test_postgresql_async.py @@ -0,0 +1,199 @@ +import os +from test_sqlite import * # NOQA + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.postgresql", + "USER": "user", + "NAME": "django", + "PASSWORD": "postgres", + "HOST": "localhost", + "PORT": 5432, + "OPTIONS": { + "server_side_binding": os.getenv("SERVER_SIDE_BINDING") == "1", + }, + }, + "other": { + "ENGINE": "django.db.backends.postgresql", + "USER": "user", + "NAME": "django2", + "PASSWORD": "postgres", + "HOST": "localhost", + "PORT": 5432, + }, +} + +# XXX REMOVE LATER +import asyncio +import signal + +# from rdrawer.output import SIO + +from io import TextIOBase + + +class SIO(TextIOBase): + buf: str + + def __init__(self, parent: "SIO | None" = None, label=None): + self.buf = "" + self.parent = parent + self.label = None + super().__init__() + + def write(self, s, /) -> int: + """ + Write input to the item, and then write back the number of characters + written + """ + self.buf += s + return len(s) + + def flush(self): + if self.parent is not None: + for line in self.buf.splitlines(keepends=True): + # write at at extra indentation + self.parent.write(f" {line}") + self.buf = "" + + def close(self): + self.flush() + if self.label is not None: + self.write("-" * 10) + super().close() + + # XXX change interface to just use the same object all the time + def group(self, label=None): + if label is not None: + self.write("|" + label) + self.write("-" * (len(label) + 1) + "\n") + return SIO(parent=self) + + def print(self, f): + self.write(f + "\n") + + +def output_pending_tasks(signum, frame): + print("PENDING HOOK TASK TRIGGERED") + import traceback + + try: + # Some code that raises an exception + 1 / 0 + except Exception as e: + # Print the traceback + traceback.print_exc() + tasks = asyncio.all_tasks(loop=asyncio.get_event_loop()) + sio = SIO() + + sio.print(f"{len(tasks)} pending tasks") + sio.print("Tasks are...") + for task in tasks: + from rdrawer.asyncio import describe_awaitable + + with sio.group(label="Task") as group: + describe_awaitable(task, group) + print(sio.buf) + + +def pending_task_hook(): + signal.signal(signal.SIGUSR2, output_pending_tasks) + + +pending_task_hook() +import asyncio +import inspect +from asyncio import Future, Task +from inspect import _Traceback, FrameInfo +from typing import Any + + +def is_asyncio_shield(stack: list[FrameInfo]): + return stack[0].frame.f_code == asyncio.shield.__code__ + + +def described_stack(stack: list[FrameInfo]): + result = "" + if is_asyncio_shield(stack): + result += "! Asyncio.shield found\n" + for frame in stack: + ctx = ( + frame.code_context[frame.index or 0] or "(Unknown)" + if frame.code_context + else "(Unknown)" + ) + if ctx[-1] != "\n": + ctx += "\n" + result += f"At {frame.filename}:{frame.lineno}\n" + result += f"-> {ctx}" + result += "\n" + return result + + +class TracedFuture(asyncio.Future): + trace: list[FrameInfo] + + def __init__(self, *, loop) -> None: + super().__init__(loop=loop) + self.trace = inspect.stack(context=3)[2:] + + @property + def is_asyncio_shield_call(self): + return is_asyncio_shield(self.trace) + + def get_shielded_future(self): + # Only valid if working on an asyncio.shield call + return self.trace[0].frame.f_locals["inner"] + + def describe_context(self, sio: SIO): + out = described_stack(self.trace) + sio.print(out) + if self.is_asyncio_shield_call: + with sio.group("Shielded Future") as fut_sio: + describe_awaitable(self.get_shielded_future(), fut_sio) + + def described_context(self): + return described_stack(self.trace) + + +def describe_awaitable(awaitable, sio: SIO): + if isinstance(awaitable, Task): + task = awaitable + task.print_stack(file=sio) + if task._fut_waiter is not None: + with sio.group("Waiting on") as wait_on_grp: + describe_awaitable(task._fut_waiter, wait_on_grp) + + # awaiting_fut = task._fut_waiter + # if hasattr(awaiting_fut, "describe_context"): + # awaiting_fut.describe_context(wait_on_grp) + # else: + # wait_on_grp.print(f"Waiting on future of type {awaiting_fut}") + else: + sio.print("Not waiting?") + elif isinstance(awaitable, TracedFuture): + fut = awaitable + sio.print(str(fut)) + fut.describe_context(sio) + else: + sio.print("Unknown awaitable...") + sio.print(str(awaitable)) + + +class TracingEventLoop(asyncio.SelectorEventLoop): + """ + An event loop that should keep track of where futures + are created + """ + + def create_future(self) -> Future[Any]: + print("CREATED FUTURE") + return TracedFuture(loop=self) + + +def tracing_event_loop_factory() -> type[asyncio.AbstractEventLoop]: + print("GOT POLICY") + return TracingEventLoop + + +asyncio.set_event_loop(TracingEventLoop()) diff --git a/tests/test_runner/test_discover_runner.py b/tests/test_runner/test_discover_runner.py index 4c4a22397b63..986ebe7603fb 100644 --- a/tests/test_runner/test_discover_runner.py +++ b/tests/test_runner/test_discover_runner.py @@ -102,6 +102,7 @@ def test_get_max_test_processes_forkserver( self.assertEqual(get_max_test_processes(), 1) +@unittest.skip("XXX fix up later") class DiscoverRunnerTests(SimpleTestCase): @staticmethod def get_test_methods_names(suite): diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py index 9fe8c58593bb..18d6be162a95 100644 --- a/tests/transactions/tests.py +++ b/tests/transactions/tests.py @@ -8,7 +8,9 @@ Error, IntegrityError, OperationalError, + allow_async_db_commits, connection, + new_connection, transaction, ) from django.test import ( @@ -577,3 +579,97 @@ class DurableTransactionTests(DurableTestsBase, TransactionTestCase): class DurableTests(DurableTestsBase, TestCase): pass + + +@skipUnlessDBFeature("uses_savepoints") +@skipUnless(connection.supports_async is True, "Async DB test") +class AsyncTransactionTestCase(TransactionTestCase): + available_apps = ["transactions"] + + async def test_new_connection_nested(self): + with allow_async_db_commits(): + async with new_connection() as connection: + async with new_connection() as connection2: + await connection2.aset_autocommit(False) + async with connection2.acursor() as cursor2: + await cursor2.aexecute( + "INSERT INTO transactions_reporter " + "(first_name, last_name, email) " + "VALUES (%s, %s, %s)", + ("Sarah", "Hatoff", ""), + ) + await cursor2.aexecute("SELECT * FROM transactions_reporter") + result = await cursor2.afetchmany() + assert len(result) == 1 + + async with connection.acursor() as cursor: + await cursor.aexecute("SELECT * FROM transactions_reporter") + result = await cursor.afetchmany() + assert len(result) == 1 + + async def test_new_connection_nested2(self): + with allow_async_db_commits(): + async with new_connection() as connection: + await connection.aset_autocommit(False) + async with connection.acursor() as cursor: + await cursor.aexecute( + "INSERT INTO transactions_reporter (first_name, last_name, email) " + "VALUES (%s, %s, %s)", + ("Tina", "Gravita", ""), + ) + await cursor.aexecute("SELECT * FROM transactions_reporter") + result = await cursor.afetchmany() + assert len(result) == 1 + + async with new_connection() as connection2: + async with connection2.acursor() as cursor2: + await cursor2.aexecute("SELECT * FROM transactions_reporter") + result = await cursor2.afetchmany() + # This connection won't see any rows, because the outer one + # hasn't committed yet. + self.assertEqual(result, []) + + async def test_new_connection_nested3(self): + with allow_async_db_commits(): + async with new_connection() as connection: + async with new_connection() as connection2: + await connection2.aset_autocommit(False) + assert id(connection) != id(connection2) + async with connection2.acursor() as cursor2: + await cursor2.aexecute( + "INSERT INTO transactions_reporter " + "(first_name, last_name, email) " + "VALUES (%s, %s, %s)", + ("Sarah", "Hatoff", ""), + ) + await cursor2.aexecute("SELECT * FROM transactions_reporter") + result = await cursor2.afetchmany() + assert len(result) == 1 + + # Outermost connection doesn't see what the innermost did, because the + # innermost connection hasn't exited yet. + async with connection.acursor() as cursor: + await cursor.aexecute("SELECT * FROM transactions_reporter") + result = await cursor.afetchmany() + assert len(result) == 0 + + async def test_asavepoint(self): + async with new_connection(force_rollback=True) as connection: + async with connection.acursor() as cursor: + sid = await connection.asavepoint() + assert sid is not None + + await cursor.aexecute( + "INSERT INTO transactions_reporter (first_name, last_name, email) " + "VALUES (%s, %s, %s)", + ("Archibald", "Haddock", ""), + ) + await cursor.aexecute("SELECT * FROM transactions_reporter") + result = await cursor.afetchmany(size=5) + assert len(result) == 1 + assert result[0][1:] == ("Archibald", "Haddock", "") + + await connection.asavepoint_rollback(sid) + await cursor.aexecute("SELECT * FROM transactions_reporter") + result = await cursor.fetchmany(size=5) + assert len(result) == 0 diff --git a/tests/xor_lookups/tests.py b/tests/xor_lookups/tests.py index d58d16cf11b8..784b6cb2de81 100644 --- a/tests/xor_lookups/tests.py +++ b/tests/xor_lookups/tests.py @@ -86,3 +86,10 @@ def test_empty_in(self): Number.objects.filter(Q(pk__in=[]) ^ Q(num__gte=5)), self.numbers[5:], ) + + def test_empty_shortcircuit(self): + # test that when working with EmptyQuerySet instances, that we shortcircuit + # by returning the original QS + qs1 = Number.objects.filter(num__gte=3) + self.assertIs(Number.objects.none() ^ qs1, qs1) + self.assertIs(qs1 ^ Number.objects.none(), qs1)