Skip to content

Commit 897c474

Browse files
authored
allow custom user models in rename_groups migration (#59)
* allow custom user models in rename_groups migration * fix lint * update version number
1 parent 0169594 commit 897c474

4 files changed

Lines changed: 158 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ x.y.z (UNRELEASED)
55
------------------
66
* Changes
77

8+
1.0.6 (2025-12-28)
9+
------------------
10+
* Add support for custom user model
11+
812
1.0.2 (2024-04-26)
913
------------------
1014
* Release Analytics Library

accounts/migrations/rename_groups.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from django.conf import settings
12
from django.db import migrations
23

34

45
def forwards_func(apps, schema_editor):
56
Group = apps.get_model("auth", "Group")
6-
User = apps.get_model("auth", "User")
7+
User = apps.get_model(settings.AUTH_USER_MODEL)
78
platform_groups = ["alum", "employee", "faculty", "member", "staff", "student"]
89

910
# Create platform scoped groups
@@ -23,5 +24,8 @@ def forwards_func(apps, schema_editor):
2324

2425

2526
class Migration(migrations.Migration):
26-
dependencies = [("accounts", "0001_initial")]
27+
dependencies = [
28+
("accounts", "0001_initial"),
29+
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
30+
]
2731
operations = [migrations.RunPython(forwards_func, None)]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "django-labs-accounts"
3-
version = "1.0.5"
3+
version = "1.0.6"
44
description = "Reusable Django app for Penn Labs accounts"
55
authors = ["Penn Labs <contact@pennlabs.org>"]
66
license = "MIT"

tests/accounts/test_migrations.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from django.apps import apps
2+
from django.conf import settings
3+
from django.contrib.auth import get_user_model
4+
from django.contrib.auth.models import Group
5+
from django.test import TestCase
6+
7+
from accounts.migrations.rename_groups import forwards_func
8+
9+
10+
class RenameGroupsMigrationTestCase(TestCase):
11+
def setUp(self):
12+
User = get_user_model()
13+
14+
self.user1 = User.objects.create_user(username="user1", email="user1@test.com")
15+
self.user2 = User.objects.create_user(username="user2", email="user2@test.com")
16+
17+
self.student_group = Group.objects.create(name="student")
18+
self.staff_group = Group.objects.create(name="staff")
19+
self.member_group = Group.objects.create(name="member")
20+
21+
self.user1.groups.add(self.student_group, self.member_group)
22+
self.user2.groups.add(self.staff_group)
23+
24+
def test_migration_function_renames_groups_correctly(self):
25+
class MockSchemaEditor:
26+
pass
27+
28+
forwards_func(apps, MockSchemaEditor())
29+
30+
# old groups should be gone
31+
self.assertFalse(Group.objects.filter(name="student").exists())
32+
self.assertFalse(Group.objects.filter(name="staff").exists())
33+
self.assertFalse(Group.objects.filter(name="member").exists())
34+
35+
# new platform_ groups should exist
36+
self.assertTrue(Group.objects.filter(name="platform_student").exists())
37+
self.assertTrue(Group.objects.filter(name="platform_staff").exists())
38+
self.assertTrue(Group.objects.filter(name="platform_member").exists())
39+
40+
def test_migration_updates_user_group_memberships(self):
41+
class MockSchemaEditor:
42+
pass
43+
44+
forwards_func(apps, MockSchemaEditor())
45+
46+
self.user1.refresh_from_db()
47+
self.user2.refresh_from_db()
48+
49+
user1_groups = set(self.user1.groups.values_list("name", flat=True))
50+
user2_groups = set(self.user2.groups.values_list("name", flat=True))
51+
52+
# users should have new platform_ groups
53+
self.assertIn("platform_student", user1_groups)
54+
self.assertIn("platform_member", user1_groups)
55+
self.assertIn("platform_staff", user2_groups)
56+
57+
# old groups should be removed
58+
self.assertNotIn("student", user1_groups)
59+
self.assertNotIn("member", user1_groups)
60+
self.assertNotIn("staff", user2_groups)
61+
62+
def test_migration_creates_all_required_platform_groups(self):
63+
platform_groups = ["alum", "employee", "faculty", "member", "staff", "student"]
64+
65+
class MockSchemaEditor:
66+
pass
67+
68+
forwards_func(apps, MockSchemaEditor())
69+
70+
for group_name in platform_groups:
71+
self.assertTrue(
72+
Group.objects.filter(name=f"platform_{group_name}").exists(),
73+
f"platform_{group_name} group should exist after migration",
74+
)
75+
76+
def test_migration_uses_auth_user_model_setting(self):
77+
# migration should use AUTH_USER_MODEL instead of hardcoded auth.User
78+
User = apps.get_model(settings.AUTH_USER_MODEL)
79+
80+
self.assertIsNotNone(User)
81+
self.assertTrue(hasattr(User, "username"))
82+
self.assertTrue(hasattr(User, "email"))
83+
self.assertTrue(hasattr(User, "groups"))
84+
85+
def test_migration_handles_users_without_platform_groups(self):
86+
User = get_user_model()
87+
88+
user3 = User.objects.create_user(username="user3", email="user3@test.com")
89+
other_group = Group.objects.create(name="other_group")
90+
user3.groups.add(other_group)
91+
92+
class MockSchemaEditor:
93+
pass
94+
95+
forwards_func(apps, MockSchemaEditor())
96+
97+
user3.refresh_from_db()
98+
user3_groups = set(user3.groups.values_list("name", flat=True))
99+
self.assertIn("other_group", user3_groups)
100+
101+
def test_migration_is_idempotent(self):
102+
# running migration multiple times shouldn't break things
103+
class MockSchemaEditor:
104+
pass
105+
106+
forwards_func(apps, MockSchemaEditor())
107+
forwards_func(apps, MockSchemaEditor())
108+
109+
platform_groups = ["alum", "employee", "faculty", "member", "staff", "student"]
110+
for group_name in platform_groups:
111+
count = Group.objects.filter(name=f"platform_{group_name}").count()
112+
self.assertEqual(count, 1)
113+
114+
def test_migration_with_empty_database(self):
115+
User = get_user_model()
116+
User.objects.all().delete()
117+
Group.objects.all().delete()
118+
119+
class MockSchemaEditor:
120+
pass
121+
122+
forwards_func(apps, MockSchemaEditor())
123+
124+
platform_groups = ["alum", "employee", "faculty", "member", "staff", "student"]
125+
for group_name in platform_groups:
126+
self.assertTrue(
127+
Group.objects.filter(name=f"platform_{group_name}").exists()
128+
)
129+
130+
131+
class CustomUserModelCompatibilityTest(TestCase):
132+
def test_get_model_with_auth_user_model_setting(self):
133+
# checks if apps.get_model works with AUTH_USER_MODEL
134+
User = apps.get_model(settings.AUTH_USER_MODEL)
135+
136+
self.assertIsNotNone(User)
137+
self.assertTrue(hasattr(User, "objects"))
138+
self.assertEqual(User, get_user_model())
139+
140+
def test_get_model_accepts_dotted_string(self):
141+
# both formats should give us the same model
142+
User1 = apps.get_model(settings.AUTH_USER_MODEL)
143+
144+
app_label, model_name = settings.AUTH_USER_MODEL.split(".")
145+
User2 = apps.get_model(app_label, model_name)
146+
147+
self.assertEqual(User1, User2)

0 commit comments

Comments
 (0)