From d83479850fcb18504843e80ea707c046e64b3b58 Mon Sep 17 00:00:00 2001 From: Lukas Anzinger Date: Thu, 21 Nov 2024 15:15:52 +0100 Subject: [PATCH 1/2] Add tests for QuerySelectField with get_group --- tests/test_main.py | 108 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 96 insertions(+), 12 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 5574681..9e1926f 100755 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -40,6 +40,20 @@ def __call__(self, field, **kwargs): ) +class LazyGroupSelect: + def __call__(self, field, **kwargs): + return list( + ( + group, + list( + (val, str(label), selected, render_kw) + for val, label, selected, render_kw in choices + ), + ) + for group, choices in field.iter_groups() + ) + + class Base: def __init__(self, **kwargs): for k, v in iter(kwargs.items()): @@ -83,7 +97,7 @@ def _do_tables(self, mapper, engine): mapper_registry.metadata.create_all(bind=engine) def _fill(self, sess): - for i, n in [(1, "apple"), (2, "banana")]: + for i, n in [(1, "apple"), (2, "banana"), (3, "apricot")]: s = self.Test(id=i, name=n) p = self.PKTest(foobar=f"hello{i}", baz=n) sess.add(s) @@ -117,7 +131,12 @@ class F(Form): self.assertTrue(form.a.data is not None) self.assertEqual(form.a.data.id, 1) self.assertEqual( - form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})] + form.a(), + [ + ("1", "apple", True, {}), + ("2", "banana", False, {}), + ("3", "apricot", False, {}), + ], ) self.assertTrue(form.validate()) @@ -155,11 +174,24 @@ class F(Form): query_factory=lambda: sess.query(self.PKTest), widget=LazySelect(), ) + d = QuerySelectField( + allow_blank=True, + blank_text="", + blank_value="", + query_factory=lambda: sess.query(self.PKTest), + get_group=lambda x: x.baz[0], + widget=LazyGroupSelect(), + ) form = F() self.assertEqual(form.a.data, None) self.assertEqual( - form.a(), [("1", "apple", False, {}), ("2", "banana", False, {})] + form.a(), + [ + ("1", "apple", False, {}), + ("2", "banana", False, {}), + ("3", "apricot", False, {}), + ], ) self.assertEqual(form.b.data, None) self.assertEqual( @@ -168,6 +200,7 @@ class F(Form): ("__None", "", True, {}), ("hello1", "apple", False, {}), ("hello2", "banana", False, {}), + ("hello3", "apricot", False, {}), ], ) self.assertEqual(form.c.data, None) @@ -177,14 +210,31 @@ class F(Form): ("", "", True, {}), ("hello1", "apple", False, {}), ("hello2", "banana", False, {}), + ("hello3", "apricot", False, {}), + ], + ) + self.assertEqual(form.d.data, None) + self.assertEqual( + form.d(), + [ + ( + "a", + [("hello1", "apple", False, {}), ("hello3", "apricot", False, {})], + ), + ("b", [("hello2", "banana", False, {})]), ], ) self.assertFalse(form.validate()) - form = F(DummyPostData(a=["1"], b=["hello2"], c=[""])) + form = F(DummyPostData(a=["1"], b=["hello2"], c=[""], d=["hello3"])) self.assertEqual(form.a.data.id, 1) self.assertEqual( - form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})] + form.a(), + [ + ("1", "apple", True, {}), + ("2", "banana", False, {}), + ("3", "apricot", False, {}), + ], ) self.assertEqual(form.b.data.baz, "banana") self.assertEqual( @@ -193,6 +243,7 @@ class F(Form): ("__None", "", False, {}), ("hello1", "apple", False, {}), ("hello2", "banana", True, {}), + ("hello3", "apricot", False, {}), ], ) self.assertEqual(form.c.data, None) @@ -202,16 +253,33 @@ class F(Form): ("", "", True, {}), ("hello1", "apple", False, {}), ("hello2", "banana", False, {}), + ("hello3", "apricot", False, {}), + ], + ) + self.assertEqual(form.d.data.baz, "apricot") + self.assertEqual( + form.d(), + [ + ( + "a", + [("hello1", "apple", False, {}), ("hello3", "apricot", True, {})], + ), + ("b", [("hello2", "banana", False, {})]), ], ) self.assertTrue(form.validate()) # Make sure the query is cached - sess.add(self.Test(id=3, name="meh")) + sess.add(self.Test(id=4, name="meh")) sess.flush() sess.commit() self.assertEqual( - form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})] + form.a(), + [ + ("1", "apple", True, {}), + ("2", "banana", False, {}), + ("3", "apricot", False, {}), + ], ) form.a._object_list = None self.assertEqual( @@ -219,7 +287,8 @@ class F(Form): [ ("1", "apple", True, {}), ("2", "banana", False, {}), - ("3", "meh", False, {}), + ("3", "apricot", False, {}), + ("4", "meh", False, {}), ], ) @@ -264,7 +333,12 @@ def test_single_value_without_factory(self): form.a.query = self.sess.query(self.Test) self.assertEqual([1], [v.id for v in form.a.data]) self.assertEqual( - form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})] + form.a(), + [ + ("1", "apple", True, {}), + ("2", "banana", False, {}), + ("3", "apricot", False, {}), + ], ) self.assertTrue(form.validate()) @@ -273,11 +347,16 @@ def test_multiple_values_without_query_factory(self): form.a.query = self.sess.query(self.Test) self.assertEqual([1, 2], [v.id for v in form.a.data]) self.assertEqual( - form.a(), [("1", "apple", True, {}), ("2", "banana", True, {})] + form.a(), + [ + ("1", "apple", True, {}), + ("2", "banana", True, {}), + ("3", "apricot", False, {}), + ], ) self.assertTrue(form.validate()) - form = self.F(DummyPostData(a=["1", "3"])) + form = self.F(DummyPostData(a=["1", "4"])) form.a.query = self.sess.query(self.Test) self.assertEqual([x.id for x in form.a.data], [1]) self.assertFalse(form.validate()) @@ -296,7 +375,12 @@ class F(Form): form = F() self.assertEqual([v.id for v in form.a.data], [2]) self.assertEqual( - form.a(), [("1", "apple", False, {}), ("2", "banana", True, {})] + form.a(), + [ + ("1", "apple", False, {}), + ("2", "banana", True, {}), + ("3", "apricot", False, {}), + ], ) self.assertTrue(form.validate()) From bf45825763ae0325e8bba67488a5be5459956f7c Mon Sep 17 00:00:00 2001 From: Lukas Anzinger Date: Thu, 21 Nov 2024 15:38:12 +0100 Subject: [PATCH 2/2] QuerySelectField: Add support for blank choice when using groups --- pyproject.toml | 2 +- src/wtforms_sqlalchemy/fields.py | 8 +++++++- tests/test_main.py | 2 ++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 125e6a6..2579829 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ ] requires-python = ">=3.9" dependencies = [ - "WTForms>=3.1", + "WTForms>=3.1.2", "SQLAlchemy>=1.4", ] dynamic = ["version"] diff --git a/src/wtforms_sqlalchemy/fields.py b/src/wtforms_sqlalchemy/fields.py index b23099c..14aed3e 100644 --- a/src/wtforms_sqlalchemy/fields.py +++ b/src/wtforms_sqlalchemy/fields.py @@ -147,9 +147,12 @@ def _get_object_list(self): self._object_list = list((str(get_pk(obj)), obj) for obj in query) return self._object_list + def _get_blank_choice(self): + return (self.blank_value, self.blank_text, self.data is None, {}) + def iter_choices(self): if self.allow_blank: - yield (self.blank_value, self.blank_text, self.data is None, {}) + yield self._get_blank_choice() for pk, obj in self._get_object_list(): yield (pk, self.get_label(obj), obj == self.data, self.get_render_kw(obj)) @@ -159,6 +162,9 @@ def has_groups(self): def iter_groups(self): if self.has_groups(): + if self.allow_blank: + yield (None, [self._get_blank_choice()]) + groups = defaultdict(list) for pk, obj in self._get_object_list(): groups[self.get_group(obj)].append((pk, obj)) diff --git a/tests/test_main.py b/tests/test_main.py index 9e1926f..ee83072 100755 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -217,6 +217,7 @@ class F(Form): self.assertEqual( form.d(), [ + (None, [("", "", True, {})]), ( "a", [("hello1", "apple", False, {}), ("hello3", "apricot", False, {})], @@ -260,6 +261,7 @@ class F(Form): self.assertEqual( form.d(), [ + (None, [("", "", False, {})]), ( "a", [("hello1", "apple", False, {}), ("hello3", "apricot", True, {})],