Skip to content

Commit 3051e48

Browse files
authored
Support pydantic BaseModel classes in state (#983)
1 parent dd5647e commit 3051e48

File tree

9 files changed

+202
-20
lines changed

9 files changed

+202
-20
lines changed

mesop/dataclass_utils/BUILD

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")
1+
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYDANTIC", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")
22

33
package(
44
default_visibility = ["//build_defs:mesop_internal"],
@@ -13,7 +13,7 @@ py_library(
1313
deps = [
1414
"//mesop/components/uploader:uploaded_file",
1515
"//mesop/exceptions",
16-
] + THIRD_PARTY_PY_DEEPDIFF,
16+
] + THIRD_PARTY_PY_DEEPDIFF + THIRD_PARTY_PY_PYDANTIC,
1717
)
1818

1919
py_test(

mesop/dataclass_utils/dataclass_utils.py

+44-14
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@
99
from deepdiff import DeepDiff, Delta
1010
from deepdiff.operator import BaseOperator
1111
from deepdiff.path import parse_path
12+
from pydantic import BaseModel
1213

1314
from mesop.components.uploader.uploaded_file import UploadedFile
1415
from mesop.exceptions import MesopDeveloperException, MesopException
1516

1617
_PANDAS_OBJECT_KEY = "__pandas.DataFrame__"
18+
_PYDANTIC_OBJECT_KEY = "__pydantic.BaseModel__"
1719
_DATETIME_OBJECT_KEY = "__datetime.datetime__"
1820
_BYTES_OBJECT_KEY = "__python.bytes__"
1921
_SET_OBJECT_KEY = "__python.set__"
2022
_UPLOADED_FILE_OBJECT_KEY = "__mesop.UploadedFile__"
2123
_DIFF_ACTION_DATA_FRAME_CHANGED = "data_frame_changed"
22-
_DIFF_ACTION_UPLOADED_FILE_CHANGED = "mesop_uploaded_file_changed"
24+
_DIFF_ACTION_EQUALITY_CHANGED = "mesop_equality_changed"
2325

2426
C = TypeVar("C")
2527

@@ -36,6 +38,8 @@ def _check_has_pandas():
3638

3739
_has_pandas = _check_has_pandas()
3840

41+
pydantic_model_cache = {}
42+
3943

4044
def dataclass_with_defaults(cls: Type[C]) -> Type[C]:
4145
"""
@@ -64,6 +68,14 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]:
6468

6569
annotations = get_type_hints(cls)
6670
for name, type_hint in annotations.items():
71+
if (
72+
isinstance(type_hint, type)
73+
and has_parent(type_hint)
74+
and issubclass(type_hint, BaseModel)
75+
):
76+
pydantic_model_cache[(type_hint.__module__, type_hint.__qualname__)] = (
77+
type_hint
78+
)
6779
if name not in cls.__dict__: # Skip if default already set
6880
if type_hint == int:
6981
setattr(cls, name, field(default=0))
@@ -187,6 +199,15 @@ def default(self, obj):
187199
}
188200
}
189201

202+
if isinstance(obj, BaseModel):
203+
return {
204+
_PYDANTIC_OBJECT_KEY: {
205+
"json": obj.model_dump_json(),
206+
"module": obj.__class__.__module__,
207+
"qualname": obj.__class__.__qualname__,
208+
}
209+
}
210+
190211
if isinstance(obj, datetime):
191212
return {_DATETIME_OBJECT_KEY: obj.isoformat()}
192213

@@ -221,6 +242,18 @@ def decode_mesop_json_state_hook(dct):
221242
if _PANDAS_OBJECT_KEY in dct:
222243
return pd.read_json(StringIO(dct[_PANDAS_OBJECT_KEY]), orient="table")
223244

245+
if _PYDANTIC_OBJECT_KEY in dct:
246+
cache_key = (
247+
dct[_PYDANTIC_OBJECT_KEY]["module"],
248+
dct[_PYDANTIC_OBJECT_KEY]["qualname"],
249+
)
250+
if cache_key not in pydantic_model_cache:
251+
raise MesopException(
252+
f"Tried to deserialize Pydantic model, but it's not in the cache: {cache_key}"
253+
)
254+
model_class = pydantic_model_cache[cache_key]
255+
return model_class.model_validate_json(dct[_PYDANTIC_OBJECT_KEY]["json"])
256+
224257
if _DATETIME_OBJECT_KEY in dct:
225258
return datetime.fromisoformat(dct[_DATETIME_OBJECT_KEY])
226259

@@ -269,25 +302,22 @@ def give_up_diffing(self, level, diff_instance) -> bool:
269302
return True
270303

271304

272-
class UploadedFileOperator(BaseOperator):
273-
"""Custom operator to detect changes in UploadedFile class.
305+
class EqualityOperator(BaseOperator):
306+
"""Custom operator to detect changes with direct equality.
274307
275308
DeepDiff does not diff the UploadedFile class correctly, so we will just use a normal
276309
equality check, rather than diffing further into the io.BytesIO parent class.
277-
278-
This class could probably be made more generic to handle other classes where we want
279-
to diff using equality checks.
280310
"""
281311

282312
def match(self, level) -> bool:
283-
return isinstance(level.t1, UploadedFile) and isinstance(
284-
level.t2, UploadedFile
313+
return isinstance(level.t1, (UploadedFile, BaseModel)) and isinstance(
314+
level.t2, (UploadedFile, BaseModel)
285315
)
286316

287317
def give_up_diffing(self, level, diff_instance) -> bool:
288318
if level.t1 != level.t2:
289319
diff_instance.custom_report_result(
290-
_DIFF_ACTION_UPLOADED_FILE_CHANGED, level, {"value": level.t2}
320+
_DIFF_ACTION_EQUALITY_CHANGED, level, {"value": level.t2}
291321
)
292322
return True
293323

@@ -306,7 +336,7 @@ def diff_state(state1: Any, state2: Any) -> str:
306336
raise MesopException("Tried to diff state which was not a dataclass")
307337

308338
custom_actions = []
309-
custom_operators = [UploadedFileOperator()]
339+
custom_operators = [EqualityOperator()]
310340
# Only use the `DataFrameOperator` if pandas exists.
311341
if _has_pandas:
312342
differences = DeepDiff(
@@ -328,15 +358,15 @@ def diff_state(state1: Any, state2: Any) -> str:
328358
else:
329359
differences = DeepDiff(state1, state2, custom_operators=custom_operators)
330360

331-
# Manually format UploadedFile diffs to flat dict format.
332-
if _DIFF_ACTION_UPLOADED_FILE_CHANGED in differences:
361+
# Manually format diffs to flat dict format.
362+
if _DIFF_ACTION_EQUALITY_CHANGED in differences:
333363
custom_actions = [
334364
{
335365
"path": parse_path(path),
336-
"action": _DIFF_ACTION_UPLOADED_FILE_CHANGED,
366+
"action": _DIFF_ACTION_EQUALITY_CHANGED,
337367
**diff,
338368
}
339-
for path, diff in differences[_DIFF_ACTION_UPLOADED_FILE_CHANGED].items()
369+
for path, diff in differences[_DIFF_ACTION_EQUALITY_CHANGED].items()
340370
]
341371

342372
# Handle the set case which will have a modified path after being JSON encoded.

mesop/dataclass_utils/dataclass_utils_test.py

+82
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import pandas as pd
66
import pytest
7+
from pydantic import BaseModel
78

89
import mesop.protos.ui_pb2 as pb
910
from mesop.components.uploader.uploaded_file import UploadedFile
@@ -49,6 +50,35 @@ class WithUploadedFile:
4950
data: UploadedFile = field(default_factory=UploadedFile)
5051

5152

53+
class NestedPydanticModel(BaseModel):
54+
default_value: str = "default"
55+
no_default_value: str
56+
57+
58+
class PydanticModel(BaseModel):
59+
name: str = "World"
60+
counter: int = 0
61+
list_models: list[NestedPydanticModel] = field(default_factory=lambda: [])
62+
nested: NestedPydanticModel = field(
63+
default_factory=lambda: NestedPydanticModel(
64+
no_default_value="<no_default_factory>"
65+
)
66+
)
67+
optional_value: str | None = None
68+
union_value: str | int = 0
69+
tuple_value: tuple[str, int] = ("a", 1)
70+
71+
72+
@dataclass_with_defaults
73+
class WithPydanticModel:
74+
data: PydanticModel
75+
76+
77+
@dataclass_with_defaults
78+
class WithPydanticModelDefaultFactory:
79+
default_factory: PydanticModel = field(default_factory=PydanticModel)
80+
81+
5282
JSON_STR = """{"b": {"c": {"val": "<init>"}},
5383
"list_b": [
5484
{"c": {"val": "1"}},
@@ -180,6 +210,58 @@ def test_serialize_uploaded_file():
180210
)
181211

182212

213+
def test_serialize_deserialize_pydantic_model():
214+
state = WithPydanticModel()
215+
state.data.name = "Hello"
216+
state.data.counter = 1
217+
state.data.nested = NestedPydanticModel(no_default_value="no_default")
218+
state.data.list_models.append(
219+
NestedPydanticModel(no_default_value="no_default_list_model_val_1")
220+
)
221+
state.data.list_models.append(
222+
NestedPydanticModel(no_default_value="no_default_list_model_val_2")
223+
)
224+
new_state = WithPydanticModel()
225+
update_dataclass_from_json(new_state, serialize_dataclass(state))
226+
assert new_state == state
227+
228+
229+
def test_serialize_deserialize_pydantic_model_set_optional_value():
230+
state = WithPydanticModel()
231+
state.data.optional_value = "optional"
232+
new_state = WithPydanticModel()
233+
update_dataclass_from_json(new_state, serialize_dataclass(state))
234+
assert new_state == state
235+
236+
237+
def test_serialize_deserialize_pydantic_model_set_union_value():
238+
state = WithPydanticModel()
239+
state.data.union_value = "union_value"
240+
new_state = WithPydanticModel()
241+
update_dataclass_from_json(new_state, serialize_dataclass(state))
242+
assert new_state == state
243+
244+
245+
def test_serialize_deserialize_pydantic_model_set_tuple_value():
246+
state = WithPydanticModel()
247+
state.data.tuple_value = ("tuple_value", 1)
248+
new_state = WithPydanticModel()
249+
update_dataclass_from_json(new_state, serialize_dataclass(state))
250+
assert new_state == state
251+
252+
253+
def test_serialize_deserialize_pydantic_model_default_factory():
254+
state = WithPydanticModelDefaultFactory()
255+
state.default_factory.name = "Hello"
256+
state.default_factory.counter = 1
257+
state.default_factory.nested = NestedPydanticModel(
258+
no_default_value="no_default"
259+
)
260+
new_state = WithPydanticModelDefaultFactory()
261+
update_dataclass_from_json(new_state, serialize_dataclass(state))
262+
assert new_state == state
263+
264+
183265
@pytest.mark.parametrize(
184266
"input_bytes, expected_json",
185267
[

mesop/dataclass_utils/diff_state_test.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pandas as pd
77
import pytest
8+
from pydantic import BaseModel
89

910
from mesop.components.uploader.uploaded_file import UploadedFile
1011
from mesop.dataclass_utils.dataclass_utils import diff_state
@@ -409,7 +410,7 @@ class C:
409410
assert json.loads(diff_state(s1, s2)) == [
410411
{
411412
"path": ["data"],
412-
"action": "mesop_uploaded_file_changed",
413+
"action": "mesop_equality_changed",
413414
"value": {
414415
"__mesop.UploadedFile__": {
415416
"contents": "ZGF0YQ==",
@@ -422,6 +423,33 @@ class C:
422423
]
423424

424425

426+
def test_diff_pydantic_model():
427+
class PydanticModel(BaseModel):
428+
name: str = "World"
429+
counter: int = 0
430+
431+
@dataclass
432+
class C:
433+
data: PydanticModel
434+
435+
s1 = C(data=PydanticModel())
436+
s2 = C(data=PydanticModel(name="Hello", counter=1))
437+
438+
assert json.loads(diff_state(s1, s2)) == [
439+
{
440+
"path": ["data"],
441+
"action": "mesop_equality_changed",
442+
"value": {
443+
"__pydantic.BaseModel__": {
444+
"json": '{"name":"Hello","counter":1}',
445+
"module": "dataclass_utils.diff_state_test",
446+
"qualname": "test_diff_pydantic_model.<locals>.PydanticModel",
447+
},
448+
},
449+
}
450+
]
451+
452+
425453
def test_diff_uploaded_file_same_no_diff():
426454
@dataclass
427455
class C:

mesop/examples/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from mesop.examples import on_load_generator as on_load_generator
3636
from mesop.examples import playground as playground
3737
from mesop.examples import playground_critic as playground_critic
38+
from mesop.examples import pydantic_state as pydantic_state
3839
from mesop.examples import query_params as query_params
3940
from mesop.examples import readme_app as readme_app
4041
from mesop.examples import responsive_layout as responsive_layout

mesop/examples/pydantic_state.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from pydantic import BaseModel
2+
3+
import mesop as me
4+
5+
6+
class PydanticModel(BaseModel):
7+
name: str = "World"
8+
counter: int = 0
9+
10+
11+
@me.stateclass
12+
class State:
13+
model: PydanticModel
14+
15+
16+
@me.page(path="/pydantic_state")
17+
def main():
18+
state = me.state(State)
19+
me.text(f"Name: {state.model.name}")
20+
me.text(f"Counter: {state.model.counter}")
21+
22+
me.button("Increment Counter", on_click=on_click)
23+
24+
25+
def on_click(e: me.ClickEvent):
26+
state = me.state(State)
27+
state.model.counter += 1
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import {test, expect} from '@playwright/test';
2+
3+
test('pydantic state is serialized and deserialized properly', async ({
4+
page,
5+
}) => {
6+
await page.goto('/pydantic_state');
7+
8+
await expect(page.getByText('Name: world')).toBeVisible();
9+
await expect(page.getByText('Counter: 0')).toBeVisible();
10+
await page.getByRole('button', {name: 'Increment Counter'}).click();
11+
await expect(page.getByText('Counter: 1')).toBeVisible();
12+
await page.getByRole('button', {name: 'Increment Counter'}).click();
13+
await expect(page.getByText('Counter: 2')).toBeVisible();
14+
});

mesop/web/src/utils/diff.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ export function applyComponentDiff(component: Component, diff: ComponentDiff) {
7878
const STATE_DIFF_VALUES_CHANGED = 'values_changed';
7979
const STATE_DIFF_TYPE_CHANGES = 'type_changes';
8080
const STATE_DIFF_DATA_FRAME_CHANGED = 'data_frame_changed';
81-
const STATE_DIFF_UPLOADED_FILE_CHANGED = 'mesop_uploaded_file_changed';
81+
const STATE_DIFF_EQUALITY_CHANGED = 'mesop_equality_changed';
8282
const STATE_DIFF_ITERABLE_ITEM_REMOVED = 'iterable_item_removed';
8383
const STATE_DIFF_ITERABLE_ITEM_ADDED = 'iterable_item_added';
8484
const STATE_DIFF_SET_ITEM_REMOVED = 'set_item_removed';
@@ -118,7 +118,7 @@ export function applyStateDiff(stateJson: string, diffJson: string): string {
118118
row.action === STATE_DIFF_VALUES_CHANGED ||
119119
row.action === STATE_DIFF_TYPE_CHANGES ||
120120
row.action === STATE_DIFF_DATA_FRAME_CHANGED ||
121-
row.action === STATE_DIFF_UPLOADED_FILE_CHANGED
121+
row.action === STATE_DIFF_EQUALITY_CHANGED
122122
) {
123123
updateValue(root, row.path, row.value);
124124
} else if (row.action === STATE_DIFF_DICT_ITEM_ADDED) {

mesop/web/src/utils/diff_state_spec.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ describe('applyStateDiff functionality', () => {
388388
const diff = JSON.stringify([
389389
{
390390
path: ['data'],
391-
action: 'mesop_uploaded_file_changed',
391+
action: 'mesop_equality_changed',
392392
value: {
393393
'__mesop.UploadedFile__': {
394394
'contents': 'data',

0 commit comments

Comments
 (0)