diff --git a/mesop/dataclass_utils/BUILD b/mesop/dataclass_utils/BUILD index f86151989..b4d09e62e 100644 --- a/mesop/dataclass_utils/BUILD +++ b/mesop/dataclass_utils/BUILD @@ -1,4 +1,4 @@ -load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test") +load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYDANTIC", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test") package( default_visibility = ["//build_defs:mesop_internal"], @@ -13,7 +13,7 @@ py_library( deps = [ "//mesop/components/uploader:uploaded_file", "//mesop/exceptions", - ] + THIRD_PARTY_PY_DEEPDIFF, + ] + THIRD_PARTY_PY_DEEPDIFF + THIRD_PARTY_PY_PYDANTIC, ) py_test( diff --git a/mesop/dataclass_utils/dataclass_utils.py b/mesop/dataclass_utils/dataclass_utils.py index fc6108ee3..fd01ebc8b 100644 --- a/mesop/dataclass_utils/dataclass_utils.py +++ b/mesop/dataclass_utils/dataclass_utils.py @@ -9,17 +9,19 @@ from deepdiff import DeepDiff, Delta from deepdiff.operator import BaseOperator from deepdiff.path import parse_path +from pydantic import BaseModel from mesop.components.uploader.uploaded_file import UploadedFile from mesop.exceptions import MesopDeveloperException, MesopException _PANDAS_OBJECT_KEY = "__pandas.DataFrame__" +_PYDANTIC_OBJECT_KEY = "__pydantic.BaseModel__" _DATETIME_OBJECT_KEY = "__datetime.datetime__" _BYTES_OBJECT_KEY = "__python.bytes__" _SET_OBJECT_KEY = "__python.set__" _UPLOADED_FILE_OBJECT_KEY = "__mesop.UploadedFile__" _DIFF_ACTION_DATA_FRAME_CHANGED = "data_frame_changed" -_DIFF_ACTION_UPLOADED_FILE_CHANGED = "mesop_uploaded_file_changed" +_DIFF_ACTION_EQUALITY_CHANGED = "mesop_equality_changed" C = TypeVar("C") @@ -36,6 +38,8 @@ def _check_has_pandas(): _has_pandas = _check_has_pandas() +pydantic_model_cache = {} + def dataclass_with_defaults(cls: Type[C]) -> Type[C]: """ @@ -64,6 +68,14 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]: annotations = get_type_hints(cls) for name, type_hint in annotations.items(): + if ( + isinstance(type_hint, type) + and has_parent(type_hint) + and issubclass(type_hint, BaseModel) + ): + pydantic_model_cache[(type_hint.__module__, type_hint.__qualname__)] = ( + type_hint + ) if name not in cls.__dict__: # Skip if default already set if type_hint == int: setattr(cls, name, field(default=0)) @@ -187,6 +199,15 @@ def default(self, obj): } } + if isinstance(obj, BaseModel): + return { + _PYDANTIC_OBJECT_KEY: { + "json": obj.model_dump_json(), + "module": obj.__class__.__module__, + "qualname": obj.__class__.__qualname__, + } + } + if isinstance(obj, datetime): return {_DATETIME_OBJECT_KEY: obj.isoformat()} @@ -221,6 +242,18 @@ def decode_mesop_json_state_hook(dct): if _PANDAS_OBJECT_KEY in dct: return pd.read_json(StringIO(dct[_PANDAS_OBJECT_KEY]), orient="table") + if _PYDANTIC_OBJECT_KEY in dct: + cache_key = ( + dct[_PYDANTIC_OBJECT_KEY]["module"], + dct[_PYDANTIC_OBJECT_KEY]["qualname"], + ) + if cache_key not in pydantic_model_cache: + raise MesopException( + f"Tried to deserialize Pydantic model, but it's not in the cache: {cache_key}" + ) + model_class = pydantic_model_cache[cache_key] + return model_class.model_validate_json(dct[_PYDANTIC_OBJECT_KEY]["json"]) + if _DATETIME_OBJECT_KEY in dct: return datetime.fromisoformat(dct[_DATETIME_OBJECT_KEY]) @@ -269,25 +302,22 @@ def give_up_diffing(self, level, diff_instance) -> bool: return True -class UploadedFileOperator(BaseOperator): - """Custom operator to detect changes in UploadedFile class. +class EqualityOperator(BaseOperator): + """Custom operator to detect changes with direct equality. DeepDiff does not diff the UploadedFile class correctly, so we will just use a normal equality check, rather than diffing further into the io.BytesIO parent class. - - This class could probably be made more generic to handle other classes where we want - to diff using equality checks. """ def match(self, level) -> bool: - return isinstance(level.t1, UploadedFile) and isinstance( - level.t2, UploadedFile + return isinstance(level.t1, (UploadedFile, BaseModel)) and isinstance( + level.t2, (UploadedFile, BaseModel) ) def give_up_diffing(self, level, diff_instance) -> bool: if level.t1 != level.t2: diff_instance.custom_report_result( - _DIFF_ACTION_UPLOADED_FILE_CHANGED, level, {"value": level.t2} + _DIFF_ACTION_EQUALITY_CHANGED, level, {"value": level.t2} ) return True @@ -306,7 +336,7 @@ def diff_state(state1: Any, state2: Any) -> str: raise MesopException("Tried to diff state which was not a dataclass") custom_actions = [] - custom_operators = [UploadedFileOperator()] + custom_operators = [EqualityOperator()] # Only use the `DataFrameOperator` if pandas exists. if _has_pandas: differences = DeepDiff( @@ -328,15 +358,15 @@ def diff_state(state1: Any, state2: Any) -> str: else: differences = DeepDiff(state1, state2, custom_operators=custom_operators) - # Manually format UploadedFile diffs to flat dict format. - if _DIFF_ACTION_UPLOADED_FILE_CHANGED in differences: + # Manually format diffs to flat dict format. + if _DIFF_ACTION_EQUALITY_CHANGED in differences: custom_actions = [ { "path": parse_path(path), - "action": _DIFF_ACTION_UPLOADED_FILE_CHANGED, + "action": _DIFF_ACTION_EQUALITY_CHANGED, **diff, } - for path, diff in differences[_DIFF_ACTION_UPLOADED_FILE_CHANGED].items() + for path, diff in differences[_DIFF_ACTION_EQUALITY_CHANGED].items() ] # Handle the set case which will have a modified path after being JSON encoded. diff --git a/mesop/dataclass_utils/dataclass_utils_test.py b/mesop/dataclass_utils/dataclass_utils_test.py index 7fadc3301..2cd5bd2b9 100644 --- a/mesop/dataclass_utils/dataclass_utils_test.py +++ b/mesop/dataclass_utils/dataclass_utils_test.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pytest +from pydantic import BaseModel import mesop.protos.ui_pb2 as pb from mesop.components.uploader.uploaded_file import UploadedFile @@ -49,6 +50,35 @@ class WithUploadedFile: data: UploadedFile = field(default_factory=UploadedFile) +class NestedPydanticModel(BaseModel): + default_value: str = "default" + no_default_value: str + + +class PydanticModel(BaseModel): + name: str = "World" + counter: int = 0 + list_models: list[NestedPydanticModel] = field(default_factory=lambda: []) + nested: NestedPydanticModel = field( + default_factory=lambda: NestedPydanticModel( + no_default_value="" + ) + ) + optional_value: str | None = None + union_value: str | int = 0 + tuple_value: tuple[str, int] = ("a", 1) + + +@dataclass_with_defaults +class WithPydanticModel: + data: PydanticModel + + +@dataclass_with_defaults +class WithPydanticModelDefaultFactory: + default_factory: PydanticModel = field(default_factory=PydanticModel) + + JSON_STR = """{"b": {"c": {"val": ""}}, "list_b": [ {"c": {"val": "1"}}, @@ -180,6 +210,58 @@ def test_serialize_uploaded_file(): ) +def test_serialize_deserialize_pydantic_model(): + state = WithPydanticModel() + state.data.name = "Hello" + state.data.counter = 1 + state.data.nested = NestedPydanticModel(no_default_value="no_default") + state.data.list_models.append( + NestedPydanticModel(no_default_value="no_default_list_model_val_1") + ) + state.data.list_models.append( + NestedPydanticModel(no_default_value="no_default_list_model_val_2") + ) + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_set_optional_value(): + state = WithPydanticModel() + state.data.optional_value = "optional" + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_set_union_value(): + state = WithPydanticModel() + state.data.union_value = "union_value" + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_set_tuple_value(): + state = WithPydanticModel() + state.data.tuple_value = ("tuple_value", 1) + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_default_factory(): + state = WithPydanticModelDefaultFactory() + state.default_factory.name = "Hello" + state.default_factory.counter = 1 + state.default_factory.nested = NestedPydanticModel( + no_default_value="no_default" + ) + new_state = WithPydanticModelDefaultFactory() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + @pytest.mark.parametrize( "input_bytes, expected_json", [ diff --git a/mesop/dataclass_utils/diff_state_test.py b/mesop/dataclass_utils/diff_state_test.py index fa52a0fc4..1dd8f06e9 100644 --- a/mesop/dataclass_utils/diff_state_test.py +++ b/mesop/dataclass_utils/diff_state_test.py @@ -5,6 +5,7 @@ import pandas as pd import pytest +from pydantic import BaseModel from mesop.components.uploader.uploaded_file import UploadedFile from mesop.dataclass_utils.dataclass_utils import diff_state @@ -409,7 +410,7 @@ class C: assert json.loads(diff_state(s1, s2)) == [ { "path": ["data"], - "action": "mesop_uploaded_file_changed", + "action": "mesop_equality_changed", "value": { "__mesop.UploadedFile__": { "contents": "ZGF0YQ==", @@ -422,6 +423,33 @@ class C: ] +def test_diff_pydantic_model(): + class PydanticModel(BaseModel): + name: str = "World" + counter: int = 0 + + @dataclass + class C: + data: PydanticModel + + s1 = C(data=PydanticModel()) + s2 = C(data=PydanticModel(name="Hello", counter=1)) + + assert json.loads(diff_state(s1, s2)) == [ + { + "path": ["data"], + "action": "mesop_equality_changed", + "value": { + "__pydantic.BaseModel__": { + "json": '{"name":"Hello","counter":1}', + "module": "dataclass_utils.diff_state_test", + "qualname": "test_diff_pydantic_model..PydanticModel", + }, + }, + } + ] + + def test_diff_uploaded_file_same_no_diff(): @dataclass class C: diff --git a/mesop/examples/__init__.py b/mesop/examples/__init__.py index 21109712f..bd0bb87c7 100644 --- a/mesop/examples/__init__.py +++ b/mesop/examples/__init__.py @@ -35,6 +35,7 @@ from mesop.examples import on_load_generator as on_load_generator from mesop.examples import playground as playground from mesop.examples import playground_critic as playground_critic +from mesop.examples import pydantic_state as pydantic_state from mesop.examples import query_params as query_params from mesop.examples import readme_app as readme_app from mesop.examples import responsive_layout as responsive_layout diff --git a/mesop/examples/pydantic_state.py b/mesop/examples/pydantic_state.py new file mode 100644 index 000000000..53c3943b6 --- /dev/null +++ b/mesop/examples/pydantic_state.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +import mesop as me + + +class PydanticModel(BaseModel): + name: str = "World" + counter: int = 0 + + +@me.stateclass +class State: + model: PydanticModel + + +@me.page(path="/pydantic_state") +def main(): + state = me.state(State) + me.text(f"Name: {state.model.name}") + me.text(f"Counter: {state.model.counter}") + + me.button("Increment Counter", on_click=on_click) + + +def on_click(e: me.ClickEvent): + state = me.state(State) + state.model.counter += 1 diff --git a/mesop/tests/e2e/pydantic_state_test.ts b/mesop/tests/e2e/pydantic_state_test.ts new file mode 100644 index 000000000..dd5a7b589 --- /dev/null +++ b/mesop/tests/e2e/pydantic_state_test.ts @@ -0,0 +1,14 @@ +import {test, expect} from '@playwright/test'; + +test('pydantic state is serialized and deserialized properly', async ({ + page, +}) => { + await page.goto('/pydantic_state'); + + await expect(page.getByText('Name: world')).toBeVisible(); + await expect(page.getByText('Counter: 0')).toBeVisible(); + await page.getByRole('button', {name: 'Increment Counter'}).click(); + await expect(page.getByText('Counter: 1')).toBeVisible(); + await page.getByRole('button', {name: 'Increment Counter'}).click(); + await expect(page.getByText('Counter: 2')).toBeVisible(); +}); diff --git a/mesop/web/src/utils/diff.ts b/mesop/web/src/utils/diff.ts index c479eca66..6f77e904c 100644 --- a/mesop/web/src/utils/diff.ts +++ b/mesop/web/src/utils/diff.ts @@ -78,7 +78,7 @@ export function applyComponentDiff(component: Component, diff: ComponentDiff) { const STATE_DIFF_VALUES_CHANGED = 'values_changed'; const STATE_DIFF_TYPE_CHANGES = 'type_changes'; const STATE_DIFF_DATA_FRAME_CHANGED = 'data_frame_changed'; -const STATE_DIFF_UPLOADED_FILE_CHANGED = 'mesop_uploaded_file_changed'; +const STATE_DIFF_EQUALITY_CHANGED = 'mesop_equality_changed'; const STATE_DIFF_ITERABLE_ITEM_REMOVED = 'iterable_item_removed'; const STATE_DIFF_ITERABLE_ITEM_ADDED = 'iterable_item_added'; const STATE_DIFF_SET_ITEM_REMOVED = 'set_item_removed'; @@ -118,7 +118,7 @@ export function applyStateDiff(stateJson: string, diffJson: string): string { row.action === STATE_DIFF_VALUES_CHANGED || row.action === STATE_DIFF_TYPE_CHANGES || row.action === STATE_DIFF_DATA_FRAME_CHANGED || - row.action === STATE_DIFF_UPLOADED_FILE_CHANGED + row.action === STATE_DIFF_EQUALITY_CHANGED ) { updateValue(root, row.path, row.value); } else if (row.action === STATE_DIFF_DICT_ITEM_ADDED) { diff --git a/mesop/web/src/utils/diff_state_spec.ts b/mesop/web/src/utils/diff_state_spec.ts index 608ea3ad3..7aac12d40 100644 --- a/mesop/web/src/utils/diff_state_spec.ts +++ b/mesop/web/src/utils/diff_state_spec.ts @@ -388,7 +388,7 @@ describe('applyStateDiff functionality', () => { const diff = JSON.stringify([ { path: ['data'], - action: 'mesop_uploaded_file_changed', + action: 'mesop_equality_changed', value: { '__mesop.UploadedFile__': { 'contents': 'data',