diff --git a/dts_test_project/dts_test_app/migrations/0003_test_add_db_index.py b/dts_test_project/dts_test_app/migrations/0003_test_add_db_index.py index 4f3f6fba..c09cf419 100644 --- a/dts_test_project/dts_test_app/migrations/0003_test_add_db_index.py +++ b/dts_test_project/dts_test_app/migrations/0003_test_add_db_index.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals from django.db import models, migrations -from django.conf import settings class Migration(migrations.Migration): diff --git a/dts_test_project/dts_test_app/migrations/0004_test_alter_unique.py b/dts_test_project/dts_test_app/migrations/0004_test_alter_unique.py index 5f1828aa..d2227441 100644 --- a/dts_test_project/dts_test_app/migrations/0004_test_alter_unique.py +++ b/dts_test_project/dts_test_app/migrations/0004_test_alter_unique.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals from django.db import models, migrations -from django.conf import settings class Migration(migrations.Migration): diff --git a/src/tenant_schemas/apps.py b/src/tenant_schemas/apps.py index 92f5f503..96a5498a 100644 --- a/src/tenant_schemas/apps.py +++ b/src/tenant_schemas/apps.py @@ -97,7 +97,7 @@ def best_practice(app_configs, **kwargs): if not isinstance(default_storage, TenantStorageMixin): errors.append( Warning( - f"Your default storage engine is not tenant aware.", + "Your default storage engine is not tenant aware.", hint="Set settings.STORAGES default backend to " "'tenant_schemas.storage.TenantFileSystemStorage'", id="tenant_schemas.W003", diff --git a/src/tenant_schemas/postgresql_backend/base.py b/src/tenant_schemas/postgresql_backend/base.py index 575969fa..5d814bdd 100644 --- a/src/tenant_schemas/postgresql_backend/base.py +++ b/src/tenant_schemas/postgresql_backend/base.py @@ -1,5 +1,6 @@ import re import warnings +from contextvars import ContextVar from django.conf import settings from django.contrib.contenttypes.models import ContentType @@ -18,14 +19,19 @@ raise ImproperlyConfigured("Error loading psycopg2 or psycopg module") -ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql') +ORIGINAL_BACKEND = getattr( + settings, "ORIGINAL_BACKEND", "django.db.backends.postgresql" +) original_backend = django.db.utils.load_backend(ORIGINAL_BACKEND) -EXTRA_SEARCH_PATHS = getattr(settings, 'PG_EXTRA_SEARCH_PATHS', []) +EXTRA_SEARCH_PATHS = getattr(settings, "PG_EXTRA_SEARCH_PATHS", []) + +# ContextVar to prevent recursion when setting search_path under DEBUG=True with psycopg3 +_SETTING_SEARCH_PATH = ContextVar("ts_setting_search_path", default=False) # from the postgresql doc -SQL_IDENTIFIER_RE = re.compile(r'^[_a-zA-Z][_a-zA-Z0-9]{,62}$') -SQL_SCHEMA_NAME_RESERVED_RE = re.compile(r'^pg_', re.IGNORECASE) +SQL_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][_a-zA-Z0-9]{,62}$") +SQL_SCHEMA_NAME_RESERVED_RE = re.compile(r"^pg_", re.IGNORECASE) def _is_valid_identifier(identifier): @@ -50,6 +56,7 @@ class DatabaseWrapper(original_backend.DatabaseWrapper): """ Adds the capability to manipulate the search_path using set_tenant and set_schema_name """ + include_public_schema = True def __init__(self, *args, **kwargs): @@ -58,16 +65,19 @@ def __init__(self, *args, **kwargs): # Use a patched version of the DatabaseIntrospection that only returns the table list for the # currently selected schema. self.introspection = DatabaseSchemaIntrospection(self) + self._ts_last_path_sig = None # Cache for last applied search path signature self.set_schema_to_public() def close(self): self.search_path_set = False + self._ts_last_path_sig = None # Clear cache on close super().close() def rollback(self): super().rollback() # Django's rollback clears the search path so we have to set it again the next time. self.search_path_set = False + self._ts_last_path_sig = None # Clear cache on rollback def set_tenant(self, tenant, include_public=True): """ @@ -87,6 +97,7 @@ def set_schema(self, schema_name, include_public=True): self.include_public_schema = include_public self.set_settings_schema(schema_name) self.search_path_set = False + self._ts_last_path_sig = None # Clear cache when schema changes # Content type can no longer be cached as public and tenant schemas # have different models. If someone wants to change this, the cache # needs to be separated between public and shared schemas. If this @@ -103,18 +114,47 @@ def set_schema_to_public(self): self.set_schema(get_public_schema_name()) def set_settings_schema(self, schema_name): - self.settings_dict['SCHEMA'] = schema_name + self.settings_dict["SCHEMA"] = schema_name def get_schema(self): - warnings.warn("connection.get_schema() is deprecated, use connection.schema_name instead.", - category=DeprecationWarning) + warnings.warn( + "connection.get_schema() is deprecated, use connection.schema_name instead.", + category=DeprecationWarning, + ) return self.schema_name def get_tenant(self): - warnings.warn("connection.get_tenant() is deprecated, use connection.tenant instead.", - category=DeprecationWarning) + warnings.warn( + "connection.get_tenant() is deprecated, use connection.tenant instead.", + category=DeprecationWarning, + ) return self.tenant + def _should_set_search_path(self, path_sig): + """ + Determine if search_path needs to be set based on current configuration. + + Returns True if: + - Limit set calls is disabled OR search_path is not set + - AND the path signature has changed + """ + return ( + not get_limit_set_calls() or not self.search_path_set + ) and self._ts_last_path_sig != path_sig + + def _get_raw_cursor(self, cursor_for_search_path): + """ + Get the raw DB-API cursor for psycopg2/psycopg3 compatibility. + + In psycopg2, cursor_for_search_path may have a 'cursor' attribute + pointing to the raw DB-API cursor. + In psycopg3, the cursor object itself is the raw DB-API cursor. + """ + if hasattr(cursor_for_search_path, "cursor"): + return cursor_for_search_path.cursor + else: + return cursor_for_search_path + def _cursor(self, name=None): """ Here it happens. We hope every Django db operation using PostgreSQL @@ -126,56 +166,90 @@ def _cursor(self, name=None): else: cursor = super()._cursor() - # optionally limit the number of executions - under load, the execution - # of `set search_path` can be quite time consuming - if (not get_limit_set_calls()) or not self.search_path_set: - # Actual search_path modification for the cursor. Database will - # search schemata from left to right when looking for the object - # (table, index, sequence, etc.). - if not self.schema_name: - raise ImproperlyConfigured("Database schema not set. Did you forget " - "to call set_schema() or set_tenant()?") - _check_schema_name(self.schema_name) - public_schema_name = get_public_schema_name() - search_paths = [] - - if self.schema_name == public_schema_name: - search_paths = [public_schema_name] - elif self.include_public_schema: - search_paths = [self.schema_name, public_schema_name] - else: - search_paths = [self.schema_name] - - search_paths.extend(EXTRA_SEARCH_PATHS) - - if name: - # Named cursor can only be used once - cursor_for_search_path = self.connection.cursor() - else: - # Reuse - cursor_for_search_path = cursor - - # In the event that an error already happened in this transaction and we are going - # to rollback we should just ignore database error when setting the search_path - # if the next instruction is not a rollback it will just fail also, so - # we do not have to worry that it's not the good one - try: - cursor_for_search_path.execute('SET search_path = {0}'.format(','.join(search_paths))) - except (django.db.utils.DatabaseError, InternalError): - self.search_path_set = False - else: - self.search_path_set = True + # Calculate search paths for current tenant configuration + if not self.schema_name: + raise ImproperlyConfigured( + "Database schema not set. Did you forget " + "to call set_schema() or set_tenant()?" + ) + + _check_schema_name(self.schema_name) + public_schema_name = get_public_schema_name() + search_paths = [] + + if self.schema_name == public_schema_name: + search_paths = [public_schema_name] + elif self.include_public_schema: + search_paths = [self.schema_name, public_schema_name] + else: + search_paths = [self.schema_name] + + search_paths.extend(EXTRA_SEARCH_PATHS) + path_sig = tuple(search_paths) - if name: - cursor_for_search_path.close() + # Check if we need to set the search path + if self._should_set_search_path(path_sig): + # Prevent recursion during debug/mogrify operations with psycopg3 + if _SETTING_SEARCH_PATH.get(): + return cursor + + token = _SETTING_SEARCH_PATH.set(True) + try: + if name: + # Named cursor can only be used once + cursor_for_search_path = self.connection.cursor() + else: + # Reuse - get raw cursor to avoid Django's debug wrapper + cursor_for_search_path = cursor + raw_cursor = self._get_raw_cursor(cursor_for_search_path) + + # In the event that an error already happened in this transaction and we are going + # to rollback we should just ignore database error when setting the search_path + # if the next instruction is not a rollback it will just fail also, so + # we do not have to worry that it's not the good one + try: + # Use set_config with parameters instead of raw SQL formatting to avoid + # triggering Django's debug SQL logging that causes psycopg3 recursion + if name: + cursor_for_search_path.execute( + "SELECT set_config('search_path', %s, false)", + [",".join(search_paths)], + ) + else: + raw_cursor.execute( + "SELECT set_config('search_path', %s, false)", + [",".join(search_paths)], + ) + except (django.db.utils.DatabaseError, InternalError): + self.search_path_set = False + self._ts_last_path_sig = None + else: + self.search_path_set = True + self._ts_last_path_sig = path_sig + + if name: + cursor_for_search_path.close() + finally: + _SETTING_SEARCH_PATH.reset(token) return cursor + def last_executed_query(self, cursor, sql, params): + """ + Override to avoid opening a fresh cursor during mogrify when there are no params. + This helps prevent recursion issues with psycopg3 when DEBUG=True. + """ + if not params: # no need to mogrify, avoids opening a fresh cursor + return sql + # Delegate to the operations class + return self.ops.last_executed_query(cursor, sql, params) + class FakeTenant: """ We can't import any db model in a backend (apparently?), so this class is used for wrapping schema names in a tenant-like structure. """ + def __init__(self, schema_name): self.schema_name = schema_name diff --git a/src/tenant_schemas/tests/test_apps.py b/src/tenant_schemas/tests/test_apps.py index a9c9cb5a..83fa8dbf 100644 --- a/src/tenant_schemas/tests/test_apps.py +++ b/src/tenant_schemas/tests/test_apps.py @@ -72,7 +72,7 @@ def test_storage_engines(self): self.assertBestPractice( [ Warning( - f"Your default storage engine is not tenant aware.", + "Your default storage engine is not tenant aware.", hint="Set settings.STORAGES default backend to " "'tenant_schemas.storage.TenantFileSystemStorage'", id="tenant_schemas.W003", diff --git a/src/tenant_schemas/tests/test_psycopg3_recursion.py b/src/tenant_schemas/tests/test_psycopg3_recursion.py new file mode 100644 index 00000000..c00f7cc4 --- /dev/null +++ b/src/tenant_schemas/tests/test_psycopg3_recursion.py @@ -0,0 +1,215 @@ +from unittest.mock import patch, Mock +import django.db.utils + +from django.db import connection +from django.test import override_settings +from tenant_schemas.tests.testcases import BaseTestCase +from tenant_schemas.tests.models import Tenant +from tenant_schemas.utils import get_public_schema_name +from tenant_schemas.postgresql_backend.base import _SETTING_SEARCH_PATH, DatabaseWrapper + + +class Psycopg3RecursionFixTest(BaseTestCase): + """ + Tests for the psycopg3 recursion fix when DEBUG=True. + + The bug occurs when Django's SQL debug logging tries to format queries + using psycopg3's mogrify, which opens a new cursor, causing recursion + in our _cursor() method when setting search_path. + """ + + # Cache the original backend to avoid repeated loading in tests + _original_backend = django.db.utils.load_backend("django.db.backends.postgresql") + + @classmethod + @override_settings( + SHARED_APPS=("tenant_schemas",), + TENANT_APPS=( + "dts_test_app", + "django.contrib.contenttypes", + "django.contrib.auth", + ), + INSTALLED_APPS=( + "tenant_schemas", + "dts_test_app", + "django.contrib.contenttypes", + "django.contrib.auth", + ), + ) + def setUpClass(cls): + super().setUpClass() + cls.sync_shared() + Tenant(domain_url="test.com", schema_name=get_public_schema_name()).save( + verbosity=cls.get_verbosity() + ) + + def setUp(self): + super().setUp() + # Create a mock tenant for unit tests - no DB operations needed + test_name = self._testMethodName + self.tenant = Mock() + self.tenant.domain_url = f"{test_name}.test.com" + self.tenant.schema_name = f"test_{test_name}" + + def _create_test_wrapper(self): + """Create a test DatabaseWrapper with proper settings.""" + return DatabaseWrapper( + { + "ENGINE": "tenant_schemas.postgresql_backend", + "NAME": "test_db", + "HOST": "localhost", + "PORT": 5432, + "USER": "test", + "PASSWORD": "test", + "TIME_ZONE": "UTC", + "CONN_MAX_AGE": 0, + "CONN_HEALTH_CHECKS": False, + "AUTOCOMMIT": True, + "ATOMIC_REQUESTS": False, + "OPTIONS": {}, + "TEST": {}, + } + ) + + def test_contextvar_prevents_recursion(self): + """Test that ContextVar prevents recursion during search_path setting.""" + wrapper = self._create_test_wrapper() + wrapper.set_tenant(self.tenant) + + # Simulate being already in the middle of setting search_path + token = _SETTING_SEARCH_PATH.set(True) + try: + # Mock the super()._cursor() call and connection to avoid actual DB calls + # Use the cached original backend DatabaseWrapper that our class inherits from + with patch.object( + self._original_backend.DatabaseWrapper, "_cursor" + ) as mock_super_cursor, patch.object(wrapper, "connection") as mock_conn: + mock_super_cursor.return_value = Mock() + mock_conn.cursor.return_value = Mock() + + cursor = wrapper._cursor() + self.assertIsNotNone(cursor) + + # Should have called super()._cursor() to get the cursor + mock_super_cursor.assert_called_once() + # But should NOT have called connection.cursor() due to ContextVar guard + mock_conn.cursor.assert_not_called() + finally: + _SETTING_SEARCH_PATH.reset(token) + + def test_search_path_signature_caching(self): + """Test that search_path signature caching works correctly.""" + wrapper = self._create_test_wrapper() + wrapper.set_tenant(self.tenant) + + # Mock connection to count calls + with patch.object(wrapper, "connection") as mock_conn: + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + + # First call should set search_path + cursor1 = wrapper._cursor() + self.assertIsNotNone(cursor1) + first_call_count = mock_cursor.execute.call_count + + # Second call with same tenant should use cache + cursor2 = wrapper._cursor() + self.assertIsNotNone(cursor2) + second_call_count = mock_cursor.execute.call_count + + # Should not have called execute again (cached) + self.assertEqual(first_call_count, second_call_count) + + def test_tenant_switch_clears_cache(self): + """Test that switching tenants clears the path signature cache.""" + wrapper = self._create_test_wrapper() + + # Set initial tenant and cache a signature + wrapper.set_tenant(self.tenant) + wrapper._ts_last_path_sig = ("tenant1", "public") + + # Switch tenants - should clear cache + wrapper.set_schema_to_public() + self.assertIsNone(wrapper._ts_last_path_sig) + + def test_connection_close_clears_cache(self): + """Test that closing connection clears the cache.""" + wrapper = self._create_test_wrapper() + + # Set cache + wrapper._ts_last_path_sig = ("tenant1", "public") + + # Close should clear cache + with patch.object(wrapper, "connection"): # Mock to avoid real close + wrapper.close() + self.assertIsNone(wrapper._ts_last_path_sig) + + def test_rollback_clears_cache(self): + """Test that rollback clears the cache.""" + wrapper = self._create_test_wrapper() + + # Set cache + wrapper._ts_last_path_sig = ("tenant1", "public") + + # Rollback should clear cache + # Use the cached original backend DatabaseWrapper that our class inherits from + with patch.object( + self._original_backend.DatabaseWrapper, "rollback" + ): # Mock parent rollback + wrapper.rollback() + self.assertIsNone(wrapper._ts_last_path_sig) + + def test_last_executed_query_optimization(self): + """Test that last_executed_query skips mogrify for parameterless queries.""" + wrapper = self._create_test_wrapper() + + cursor = Mock() + + # Test with no parameters - should return SQL as-is + result = wrapper.last_executed_query(cursor, "SELECT 1", None) + self.assertEqual(result, "SELECT 1") + + result = wrapper.last_executed_query(cursor, "SELECT 1", []) + self.assertEqual(result, "SELECT 1") + + # Test with parameters - should delegate to ops + with patch.object( + wrapper.ops, "last_executed_query", return_value="formatted" + ) as mock_ops: + result = wrapper.last_executed_query(cursor, "SELECT %s", ["test"]) + self.assertEqual(result, "formatted") + mock_ops.assert_called_once_with(cursor, "SELECT %s", ["test"]) + + @override_settings(DEBUG=True) + def test_integration_with_real_connection(self): + """Integration test with real database connection and DEBUG=True.""" + # Use existing public schema to avoid tenant creation complexity + connection.set_schema_to_public() + + try: + # This should work without recursion error even with DEBUG=True + cursor = connection._cursor() + self.assertIsNotNone(cursor) + + # Execute a simple query to verify everything works + cursor.execute("SELECT 1") + result = cursor.fetchone() + self.assertEqual(result[0], 1) + + # Verify we can switch search paths without issues + # Create a mock tenant for search path testing + mock_tenant = Mock() + mock_tenant.schema_name = "public" # Use public schema that exists + + connection.set_tenant(mock_tenant) + cursor2 = connection._cursor() + self.assertIsNotNone(cursor2) + + # Another simple query to ensure search_path changes work + cursor2.execute("SELECT 2") + result2 = cursor2.fetchone() + self.assertEqual(result2[0], 2) + + finally: + # Clean up any state changes + connection.set_schema_to_public() diff --git a/src/tenant_schemas/tests/test_routes.py b/src/tenant_schemas/tests/test_routes.py index 0a9b8b7c..760a43da 100644 --- a/src/tenant_schemas/tests/test_routes.py +++ b/src/tenant_schemas/tests/test_routes.py @@ -1,6 +1,3 @@ -import unittest - -from django.conf import settings from django.core.exceptions import DisallowedHost from django.http import Http404 from django.test.client import RequestFactory @@ -22,13 +19,6 @@ class RoutesTestCase(BaseTestCase): @classmethod def setUpClass(cls): super().setUpClass() - settings.SHARED_APPS = ("tenant_schemas",) - settings.TENANT_APPS = ( - "dts_test_app", - "django.contrib.contenttypes", - "django.contrib.auth", - ) - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS cls.sync_shared() cls.public_tenant = Tenant( domain_url="test.com", schema_name=get_public_schema_name() diff --git a/src/tenant_schemas/tests/test_tenants.py b/src/tenant_schemas/tests/test_tenants.py index b2e57cbc..bfbe669b 100644 --- a/src/tenant_schemas/tests/test_tenants.py +++ b/src/tenant_schemas/tests/test_tenants.py @@ -1,6 +1,6 @@ -from django.conf import settings from django.contrib.auth.models import User from django.db import connection +from django.test import override_settings from dts_test_app.models import DummyModel, ModelWithFkToPublicUser from tenant_schemas.management.commands import tenant_command from tenant_schemas.test.cases import TenantTestCase @@ -26,13 +26,6 @@ class TenantDataAndSettingsTest(BaseTestCase): @classmethod def setUpClass(cls): super().setUpClass() - settings.SHARED_APPS = ("tenant_schemas",) - settings.TENANT_APPS = ( - "dts_test_app", - "django.contrib.contenttypes", - "django.contrib.auth", - ) - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS cls.sync_shared() Tenant(domain_url="test.com", schema_name=get_public_schema_name()).save( verbosity=cls.get_verbosity() @@ -178,36 +171,50 @@ class TenantSyncTest(BaseTestCase): MIGRATION_TABLE_SIZE = 1 + @override_settings( + SHARED_APPS=( + "tenant_schemas", # 2 tables + "django.contrib.auth", # 6 tables + "django.contrib.contenttypes", + ), # 1 table + TENANT_APPS=("django.contrib.sessions",), + INSTALLED_APPS=( + "tenant_schemas", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + ), + ) def test_shared_apps_does_not_sync_tenant_apps(self): """ Tests that if an app is in SHARED_APPS, it does not get synced to the a tenant schema. """ - settings.SHARED_APPS = ( - "tenant_schemas", # 2 tables - "django.contrib.auth", # 6 tables - "django.contrib.contenttypes", - ) # 1 table - settings.TENANT_APPS = ("django.contrib.sessions",) - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS self.sync_shared() shared_tables = self.get_tables_list_in_schema(get_public_schema_name()) self.assertEqual(2 + 6 + 1 + self.MIGRATION_TABLE_SIZE, len(shared_tables)) self.assertNotIn("django_session", shared_tables) + @override_settings( + SHARED_APPS=( + "tenant_schemas", + "django.contrib.auth", + "django.contrib.contenttypes", + ), + TENANT_APPS=("django.contrib.sessions",), # 1 table + INSTALLED_APPS=( + "tenant_schemas", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + ), + ) def test_tenant_apps_does_not_sync_shared_apps(self): """ Tests that if an app is in TENANT_APPS, it does not get synced to the public schema. """ - settings.SHARED_APPS = ( - "tenant_schemas", - "django.contrib.auth", - "django.contrib.contenttypes", - ) - settings.TENANT_APPS = ("django.contrib.sessions",) # 1 table - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS self.sync_shared() tenant = Tenant(domain_url="arbitrary.test.com", schema_name="test") tenant.save(verbosity=BaseTestCase.get_verbosity()) @@ -216,19 +223,26 @@ def test_tenant_apps_does_not_sync_shared_apps(self): self.assertEqual(1 + self.MIGRATION_TABLE_SIZE, len(tenant_tables)) self.assertIn("django_session", tenant_tables) + @override_settings( + SHARED_APPS=( + "tenant_schemas", # 2 tables + "django.contrib.auth", # 6 tables + "django.contrib.contenttypes", # 1 table + "django.contrib.sessions", + ), # 1 table + TENANT_APPS=("django.contrib.sessions",), # 1 table + INSTALLED_APPS=( + "tenant_schemas", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + ), + ) def test_tenant_apps_and_shared_apps_can_have_the_same_apps(self): """ Tests that both SHARED_APPS and TENANT_APPS can have apps in common. In this case they should get synced to both tenant and public schemas. """ - settings.SHARED_APPS = ( - "tenant_schemas", # 2 tables - "django.contrib.auth", # 6 tables - "django.contrib.contenttypes", # 1 table - "django.contrib.sessions", - ) # 1 table - settings.TENANT_APPS = ("django.contrib.sessions",) # 1 table - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS self.sync_shared() tenant = Tenant(domain_url="arbitrary.test.com", schema_name="test") tenant.save(verbosity=BaseTestCase.get_verbosity()) @@ -240,17 +254,23 @@ def test_tenant_apps_and_shared_apps_can_have_the_same_apps(self): self.assertEqual(1 + self.MIGRATION_TABLE_SIZE, len(tenant_tables)) self.assertIn("django_session", tenant_tables) + @override_settings( + SHARED_APPS=( + "tenant_schemas", # 2 tables + "django.contrib.contenttypes", + ), # 1 table + TENANT_APPS=("django.contrib.sessions",), # 1 table + INSTALLED_APPS=( + "tenant_schemas", + "django.contrib.contenttypes", + "django.contrib.sessions", + ), + ) def test_content_types_is_not_mandatory(self): """ Tests that even if content types is in SHARED_APPS, it's not required in TENANT_APPS. """ - settings.SHARED_APPS = ( - "tenant_schemas", # 2 tables - "django.contrib.contenttypes", - ) # 1 table - settings.TENANT_APPS = ("django.contrib.sessions",) # 1 table - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS self.sync_shared() tenant = Tenant(domain_url="something.test.com", schema_name="test") tenant.save(verbosity=BaseTestCase.get_verbosity()) @@ -264,17 +284,22 @@ def test_content_types_is_not_mandatory(self): class TenantCommandTest(BaseTestCase): + @override_settings( + SHARED_APPS=( + "tenant_schemas", + "django.contrib.contenttypes", + ), + TENANT_APPS=(), + INSTALLED_APPS=( + "tenant_schemas", + "django.contrib.contenttypes", + ), + ) def test_command(self): """ Tests that tenant_command is capable of wrapping commands and its parameters. """ - settings.SHARED_APPS = ( - "tenant_schemas", - "django.contrib.contenttypes", - ) - settings.TENANT_APPS = () - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS self.sync_shared() Tenant(domain_url="localhost", schema_name="public").save( verbosity=BaseTestCase.get_verbosity() @@ -300,17 +325,24 @@ def test_command(self): ) +@override_settings( + SHARED_APPS=( + "tenant_schemas", + "django.contrib.auth", + "django.contrib.contenttypes", + ), + TENANT_APPS=("dts_test_app",), + INSTALLED_APPS=( + "tenant_schemas", + "django.contrib.auth", + "django.contrib.contenttypes", + "dts_test_app", + ), +) class SharedAuthTest(BaseTestCase): @classmethod def setUpClass(cls): super().setUpClass() - settings.SHARED_APPS = ( - "tenant_schemas", - "django.contrib.auth", - "django.contrib.contenttypes", - ) - settings.TENANT_APPS = ("dts_test_app",) - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS cls.sync_shared() Tenant(domain_url="test.com", schema_name=get_public_schema_name()).save( verbosity=cls.get_verbosity() diff --git a/src/tenant_schemas/tests/testcases.py b/src/tenant_schemas/tests/testcases.py index 2bcbf438..aa6175ed 100644 --- a/src/tenant_schemas/tests/testcases.py +++ b/src/tenant_schemas/tests/testcases.py @@ -3,42 +3,52 @@ from django.conf import settings from django.core.management import call_command from django.db import connection -from django.test import TestCase +from django.test import TestCase, override_settings from tenant_schemas.utils import get_public_schema_name +@override_settings( + TENANT_MODEL="tenant_schemas.Tenant", + SHARED_APPS=("tenant_schemas",), + TENANT_APPS=("dts_test_app", "django.contrib.contenttypes", "django.contrib.auth"), + INSTALLED_APPS=( + "tenant_schemas", + "dts_test_app", + "django.contrib.contenttypes", + "django.contrib.auth", + ), +) class BaseTestCase(TestCase): """ Base test case that comes packed with overloaded INSTALLED_APPS, custom public tenant, and schemas cleanup on tearDown. """ + @classmethod def setUpClass(cls): - settings.TENANT_MODEL = 'tenant_schemas.Tenant' - settings.SHARED_APPS = ('tenant_schemas', ) - settings.TENANT_APPS = ('dts_test_app', - 'django.contrib.contenttypes', - 'django.contrib.auth', ) - settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS - if '.test.com' not in settings.ALLOWED_HOSTS: - settings.ALLOWED_HOSTS += ['.test.com'] + # Add .test.com to ALLOWED_HOSTS for testing + cls.original_allowed_hosts = list(settings.ALLOWED_HOSTS) + if ".test.com" not in settings.ALLOWED_HOSTS: + settings.ALLOWED_HOSTS = cls.original_allowed_hosts + [".test.com"] # Django calls syncdb by default for the test database, but we want # a blank public schema for this set of tests. connection.set_schema_to_public() cursor = connection.cursor() - cursor.execute('DROP SCHEMA IF EXISTS %s CASCADE; CREATE SCHEMA %s;' - % (get_public_schema_name(), get_public_schema_name())) + cursor.execute( + "DROP SCHEMA IF EXISTS %s CASCADE; CREATE SCHEMA %s;" + % (get_public_schema_name(), get_public_schema_name()) + ) super().setUpClass() @classmethod def tearDownClass(cls): + # Restore original ALLOWED_HOSTS + if hasattr(cls, "original_allowed_hosts"): + settings.ALLOWED_HOSTS = cls.original_allowed_hosts super().tearDownClass() - if '.test.com' in settings.ALLOWED_HOSTS: - settings.ALLOWED_HOSTS.remove('.test.com') - def setUp(self): connection.set_schema_to_public() super().setUp() @@ -46,9 +56,9 @@ def setUp(self): @classmethod def get_verbosity(self): for s in reversed(inspect.stack()): - options = s[0].f_locals.get('options') + options = s[0].f_locals.get("options") if isinstance(options, dict): - return int(options['verbosity']) - 2 + return int(options["verbosity"]) - 2 return 1 @classmethod @@ -56,13 +66,15 @@ def get_tables_list_in_schema(cls, schema_name): cursor = connection.cursor() sql = """SELECT table_name FROM information_schema.tables WHERE table_schema = %s""" - cursor.execute(sql, (schema_name, )) + cursor.execute(sql, (schema_name,)) return [row[0] for row in cursor.fetchall()] @classmethod def sync_shared(cls): - call_command('migrate_schemas', - schema_name=get_public_schema_name(), - interactive=False, - verbosity=cls.get_verbosity(), - run_syncdb=True) + call_command( + "migrate_schemas", + schema_name=get_public_schema_name(), + interactive=False, + verbosity=cls.get_verbosity(), + run_syncdb=True, + )