@@ -42,7 +42,7 @@ class ModelFieldsTestMixin(TestCase):
4242 """
4343
4444 def assert_model_fields_list_equal (self , db_list : list [Model ], ground_truths : list [dict ],
45- ignore_fields : list [str ], field_maps = {} ):
45+ ignore_fields : list [str ], field_maps : dict | None = None ):
4646 """
4747 List wrapper for assert_model_fields_equal.
4848 """
@@ -57,18 +57,18 @@ def assert_model_fields_list_equal(self, db_list: list[Model], ground_truths: li
5757 )
5858
5959 def assert_model_fields_equal (self , db_obj : Model , ground_truth : dict ,
60- ignore_fields : list [str ], field_maps = {} ):
60+ ignore_fields : list [str ], field_maps : dict | None = None ):
6161 """
6262 Compares the fields of db_obj (exluding ignore_fields, if any) with the values of ground_truth.
6363 """
64- MODEL_FIELDS = [f .name for f in db_obj ._meta .get_fields () if f .name not in ignore_fields ]
65- for field in MODEL_FIELDS :
64+ model_fields = [f .name for f in db_obj ._meta .get_fields () if f .name not in ignore_fields ]
65+ for field in model_fields :
6666 gt_value = ground_truth .get (field )
6767 if gt_value and field == "extra_properties" :
6868 # remove non-ingested computed properties from gt to compare
6969 gt_value = remove_computed_properties (gt_value )
7070 # Apply field mapping, if any
71- model_field = field_maps .get (field , field )
71+ model_field = ( field_maps or {}) .get (field , field )
7272 if gt_value :
7373 # we expect the db_obj to contain this ground truth value
7474 self .assertEqual (getattr (db_obj , model_field ), gt_value )
0 commit comments