Skip to content

Commit d86c6de

Browse files
feat: extend update_fields with translation fields in Model.save() (#687)
1 parent c68104c commit d86c6de

File tree

6 files changed

+184
-54
lines changed

6 files changed

+184
-54
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ jobs:
8888
if [[ $DB == postgres ]]; then
8989
pip install -q psycopg2-binary
9090
fi
91-
pip install typing-extensions coverage pytest pytest-django pytest-cov $(./get-django-version.py ${{ matrix.django }})
91+
pip install typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }})
9292
- name: Run tests
9393
run: |
9494
pytest --cov-report term

modeltranslation/manager.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
"""
77
import itertools
88
from functools import reduce
9+
from typing import List, Tuple, Type, Any, Optional
910

1011
from django import VERSION
1112
from django.contrib.admin.utils import get_model_from_relation
1213
from django.core.exceptions import FieldDoesNotExist
1314
from django.db import models
15+
from django.db.models import Field, Model
1416
from django.db.models.expressions import Col
1517
from django.db.models.lookups import Lookup
1618
from django.db.models.query import QuerySet, ValuesIterable
@@ -243,21 +245,6 @@ def select_related(self, *fields, **kwargs):
243245
new_args.append(rewrite_lookup_key(self.model, key))
244246
return super().select_related(*new_args, **kwargs)
245247

246-
def update_or_create(self, defaults=None, **kwargs):
247-
"""
248-
Updates or creates a database record with the specified kwargs. The method first
249-
rewrites the keys in the defaults dictionary using a custom function named
250-
`rewrite_lookup_key`. This ensures that the keys are valid for the current model
251-
before calling the inherited update_or_create() method from the super class.
252-
Returns the updated or created model instance.
253-
"""
254-
if defaults is not None:
255-
rewritten_defaults = {}
256-
for key, value in defaults.items():
257-
rewritten_defaults[rewrite_lookup_key(self.model, key)] = value
258-
defaults = rewritten_defaults
259-
return super().update_or_create(defaults=defaults, **kwargs)
260-
261248
# This method was not present in django-linguo
262249
def _rewrite_col(self, col):
263250
"""Django >= 1.7 column name rewriting"""
@@ -386,6 +373,27 @@ def update(self, **kwargs):
386373

387374
update.alters_data = True
388375

376+
def _update(self, values: List[Tuple[Field, Optional[Type[Model]], Any]]):
377+
"""
378+
This method is called in .save() method to update an existing record.
379+
Here we force to update translation fields as well if the original
380+
field only is passed in `save()` in argument `update_fields`.
381+
"""
382+
# TODO: Should the original field (field without lang code suffix) be updated
383+
# when only the default translation field (`field_<DEFAULT_LANG_CODE>`) is passed in `update_fields`?
384+
# Currently, we don't synchronize values of the original and default translation fields in that case.
385+
field_names_to_update = {field.name for field, *_ in values}
386+
387+
translation_values = []
388+
for field, model, value in values:
389+
translation_field_name = rewrite_lookup_key(self.model, field.name)
390+
if translation_field_name not in field_names_to_update:
391+
translatable_field = self.model._meta.get_field(translation_field_name)
392+
translation_values.append((translatable_field, model, value))
393+
394+
values += translation_values
395+
return super()._update(values)
396+
389397
# This method was not present in django-linguo
390398
@property
391399
def _populate_mode(self):

modeltranslation/tests/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _get_database_config():
2929
{
3030
'ENGINE': 'django.db.backends.postgresql',
3131
'USER': os.getenv('POSTGRES_USER', 'postgres'),
32-
'PASSWORD': os.getenv('POSTGRES_DB', 'postgres'),
32+
'PASSWORD': os.getenv('POSTGRES_PASSWORD', 'postgres'),
3333
'NAME': os.getenv('POSTGRES_DB', 'modeltranslation'),
3434
'HOST': host,
3535
}

modeltranslation/tests/tests.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from django.test import TestCase, TransactionTestCase
2323
from django.test.utils import override_settings
2424
from django.utils.translation import get_language, override, trans_real
25+
from parameterized import parameterized
2526

2627
from modeltranslation import admin
2728
from modeltranslation import settings as mt_settings
@@ -79,6 +80,20 @@ def get_field_names(model):
7980
return names
8081

8182

83+
def assert_db_record(instance, **expected_fields):
84+
"""
85+
Compares field values stored in the db.
86+
"""
87+
actual = (
88+
type(instance)
89+
.objects.rewrite(False)
90+
.filter(pk=instance.pk)
91+
.values(*expected_fields.keys())
92+
.first()
93+
)
94+
assert actual == expected_fields
95+
96+
8297
class ModeltranslationTransactionTestBase(TransactionTestCase):
8398
cache = django_apps
8499

@@ -358,6 +373,7 @@ def test_set_translation(self):
358373
assert n.title == title_de
359374
assert n.title_en == title_en
360375
assert n.title_de == title_de
376+
assert_db_record(n, title=title_de, title_de=title_de, title_en=title_en)
361377

362378
# Queries are also language-aware:
363379
assert 1 == models.TestModel.objects.filter(title=title_de).count()
@@ -463,6 +479,89 @@ def test_constructor(self):
463479
)
464480
self._test_constructor(keywords)
465481

482+
@parameterized.expand(
483+
[
484+
({'title': 'DE'}, ['title'], {'title': 'DE', 'title_de': 'DE', 'title_en': None}),
485+
({'title_de': 'DE'}, ['title'], {'title': 'DE', 'title_de': 'DE', 'title_en': None}),
486+
({'title': 'DE'}, ['title_de'], {'title': 'old', 'title_de': 'DE', 'title_en': None}),
487+
(
488+
{'title_de': 'DE'},
489+
['title_de'],
490+
{'title': 'old', 'title_de': 'DE', 'title_en': None},
491+
),
492+
(
493+
{'title': 'DE', 'title_en': 'EN'},
494+
['title', 'title_en'],
495+
{'title': 'DE', 'title_de': 'DE', 'title_en': 'EN'},
496+
),
497+
(
498+
{'title_de': 'DE', 'title_en': 'EN'},
499+
['title_de', 'title_en'],
500+
{'title': 'old', 'title_de': 'DE', 'title_en': 'EN'},
501+
),
502+
(
503+
{'title_de': 'DE', 'title_en': 'EN'},
504+
['title', 'title_de', 'title_en'],
505+
{'title': 'DE', 'title_de': 'DE', 'title_en': 'EN'},
506+
),
507+
]
508+
)
509+
def test_save_original_translation_field(self, field_values, update_fields, expected_db_values):
510+
obj = models.TestModel.objects.create(title='old')
511+
512+
for field, value in field_values.items():
513+
setattr(obj, field, value)
514+
515+
obj.save(update_fields=update_fields)
516+
assert_db_record(obj, **expected_db_values)
517+
518+
@parameterized.expand(
519+
[
520+
({'title': 'EN'}, ['title'], {'title': 'EN', 'title_de': None, 'title_en': 'EN'}),
521+
({'title_en': 'EN'}, ['title'], {'title': 'EN', 'title_de': None, 'title_en': 'EN'}),
522+
({'title': 'EN'}, ['title_en'], {'title': 'old', 'title_de': None, 'title_en': 'EN'}),
523+
(
524+
{'title_en': 'EN'},
525+
['title_en'],
526+
{'title': 'old', 'title_de': None, 'title_en': 'EN'},
527+
),
528+
(
529+
{'title': 'EN', 'title_de': 'DE'},
530+
['title', 'title_de'],
531+
{'title': 'EN', 'title_de': 'DE', 'title_en': 'EN'},
532+
),
533+
(
534+
{'title_de': 'DE', 'title_en': 'EN'},
535+
['title_de', 'title_en'],
536+
{'title': 'old', 'title_de': 'DE', 'title_en': 'EN'},
537+
),
538+
(
539+
{'title_de': 'DE', 'title_en': 'EN'},
540+
['title', 'title_de', 'title_en'],
541+
{'title': 'EN', 'title_de': 'DE', 'title_en': 'EN'},
542+
),
543+
]
544+
)
545+
def test_save_active_translation_field(self, field_values, update_fields, expected_db_values):
546+
with override('en'):
547+
obj = models.TestModel.objects.create(title='old')
548+
549+
for field, value in field_values.items():
550+
setattr(obj, field, value)
551+
552+
obj.save(update_fields=update_fields)
553+
assert_db_record(obj, **expected_db_values)
554+
555+
def test_save_non_original_translation_field(self):
556+
obj = models.TestModel.objects.create(title='old')
557+
558+
obj.title_en = 'en value'
559+
obj.save(update_fields=['title'])
560+
assert_db_record(obj, title='old', title_de='old', title_en=None)
561+
562+
obj.save(update_fields=['title_en'])
563+
assert_db_record(obj, title='old', title_de='old', title_en='en value')
564+
466565
def test_update_or_create_existing(self):
467566
"""
468567
Test that update_or_create works as expected
@@ -477,6 +576,43 @@ def test_update_or_create_existing(self):
477576
assert instance.title == 'NEW DE TITLE'
478577
assert instance.title_en == 'old en'
479578
assert instance.title_de == 'NEW DE TITLE'
579+
assert_db_record(
580+
instance,
581+
title='NEW DE TITLE',
582+
title_en='old en',
583+
title_de='NEW DE TITLE',
584+
)
585+
586+
instance, created = models.TestModel.objects.update_or_create(
587+
pk=obj.pk, defaults={'title_de': 'NEW DE TITLE 2'}
588+
)
589+
590+
assert created is False
591+
assert instance.title == 'NEW DE TITLE 2'
592+
assert instance.title_en == 'old en'
593+
assert instance.title_de == 'NEW DE TITLE 2'
594+
assert_db_record(
595+
instance,
596+
# title='NEW DE TITLE', # TODO: django < 4.2 doesn't pass `"title"` into `.save(update_fields)`
597+
title_en='old en',
598+
title_de='NEW DE TITLE 2',
599+
)
600+
601+
with override('en'):
602+
instance, created = models.TestModel.objects.update_or_create(
603+
pk=obj.pk, defaults={'title': 'NEW EN TITLE'}
604+
)
605+
606+
assert created is False
607+
assert instance.title == 'NEW EN TITLE'
608+
assert instance.title_en == 'NEW EN TITLE'
609+
assert instance.title_de == 'NEW DE TITLE 2'
610+
assert_db_record(
611+
instance,
612+
title='NEW EN TITLE',
613+
title_en='NEW EN TITLE',
614+
title_de='NEW DE TITLE 2',
615+
)
480616

481617
def test_update_or_create_new(self):
482618
instance, created = models.TestModel.objects.update_or_create(
@@ -488,6 +624,12 @@ def test_update_or_create_new(self):
488624
assert instance.title == 'old de'
489625
assert instance.title_en == 'old en'
490626
assert instance.title_de == 'old de'
627+
assert_db_record(
628+
instance,
629+
title='old de',
630+
title_en='old en',
631+
title_de='old de',
632+
)
491633

492634

493635
class ModeltranslationTransactionTest(ModeltranslationTransactionTestBase):

0 commit comments

Comments
 (0)