Skip to content

Commit

Permalink
Support pydantic BaseModel classes in state (#983)
Browse files Browse the repository at this point in the history
  • Loading branch information
wwwillchen authored Sep 26, 2024
1 parent dd5647e commit 3051e48
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 20 deletions.
4 changes: 2 additions & 2 deletions mesop/dataclass_utils/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYDANTIC", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")

package(
default_visibility = ["//build_defs:mesop_internal"],
Expand All @@ -13,7 +13,7 @@ py_library(
deps = [
"//mesop/components/uploader:uploaded_file",
"//mesop/exceptions",
] + THIRD_PARTY_PY_DEEPDIFF,
] + THIRD_PARTY_PY_DEEPDIFF + THIRD_PARTY_PY_PYDANTIC,
)

py_test(
Expand Down
58 changes: 44 additions & 14 deletions mesop/dataclass_utils/dataclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
from deepdiff import DeepDiff, Delta
from deepdiff.operator import BaseOperator
from deepdiff.path import parse_path
from pydantic import BaseModel

from mesop.components.uploader.uploaded_file import UploadedFile
from mesop.exceptions import MesopDeveloperException, MesopException

_PANDAS_OBJECT_KEY = "__pandas.DataFrame__"
_PYDANTIC_OBJECT_KEY = "__pydantic.BaseModel__"
_DATETIME_OBJECT_KEY = "__datetime.datetime__"
_BYTES_OBJECT_KEY = "__python.bytes__"
_SET_OBJECT_KEY = "__python.set__"
_UPLOADED_FILE_OBJECT_KEY = "__mesop.UploadedFile__"
_DIFF_ACTION_DATA_FRAME_CHANGED = "data_frame_changed"
_DIFF_ACTION_UPLOADED_FILE_CHANGED = "mesop_uploaded_file_changed"
_DIFF_ACTION_EQUALITY_CHANGED = "mesop_equality_changed"

C = TypeVar("C")

Expand All @@ -36,6 +38,8 @@ def _check_has_pandas():

_has_pandas = _check_has_pandas()

pydantic_model_cache = {}


def dataclass_with_defaults(cls: Type[C]) -> Type[C]:
"""
Expand Down Expand Up @@ -64,6 +68,14 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]:

annotations = get_type_hints(cls)
for name, type_hint in annotations.items():
if (
isinstance(type_hint, type)
and has_parent(type_hint)
and issubclass(type_hint, BaseModel)
):
pydantic_model_cache[(type_hint.__module__, type_hint.__qualname__)] = (
type_hint
)
if name not in cls.__dict__: # Skip if default already set
if type_hint == int:
setattr(cls, name, field(default=0))
Expand Down Expand Up @@ -187,6 +199,15 @@ def default(self, obj):
}
}

if isinstance(obj, BaseModel):
return {
_PYDANTIC_OBJECT_KEY: {
"json": obj.model_dump_json(),
"module": obj.__class__.__module__,
"qualname": obj.__class__.__qualname__,
}
}

if isinstance(obj, datetime):
return {_DATETIME_OBJECT_KEY: obj.isoformat()}

Expand Down Expand Up @@ -221,6 +242,18 @@ def decode_mesop_json_state_hook(dct):
if _PANDAS_OBJECT_KEY in dct:
return pd.read_json(StringIO(dct[_PANDAS_OBJECT_KEY]), orient="table")

if _PYDANTIC_OBJECT_KEY in dct:
cache_key = (
dct[_PYDANTIC_OBJECT_KEY]["module"],
dct[_PYDANTIC_OBJECT_KEY]["qualname"],
)
if cache_key not in pydantic_model_cache:
raise MesopException(
f"Tried to deserialize Pydantic model, but it's not in the cache: {cache_key}"
)
model_class = pydantic_model_cache[cache_key]
return model_class.model_validate_json(dct[_PYDANTIC_OBJECT_KEY]["json"])

if _DATETIME_OBJECT_KEY in dct:
return datetime.fromisoformat(dct[_DATETIME_OBJECT_KEY])

Expand Down Expand Up @@ -269,25 +302,22 @@ def give_up_diffing(self, level, diff_instance) -> bool:
return True


class UploadedFileOperator(BaseOperator):
"""Custom operator to detect changes in UploadedFile class.
class EqualityOperator(BaseOperator):
"""Custom operator to detect changes with direct equality.
DeepDiff does not diff the UploadedFile class correctly, so we will just use a normal
equality check, rather than diffing further into the io.BytesIO parent class.
This class could probably be made more generic to handle other classes where we want
to diff using equality checks.
"""

def match(self, level) -> bool:
return isinstance(level.t1, UploadedFile) and isinstance(
level.t2, UploadedFile
return isinstance(level.t1, (UploadedFile, BaseModel)) and isinstance(
level.t2, (UploadedFile, BaseModel)
)

def give_up_diffing(self, level, diff_instance) -> bool:
if level.t1 != level.t2:
diff_instance.custom_report_result(
_DIFF_ACTION_UPLOADED_FILE_CHANGED, level, {"value": level.t2}
_DIFF_ACTION_EQUALITY_CHANGED, level, {"value": level.t2}
)
return True

Expand All @@ -306,7 +336,7 @@ def diff_state(state1: Any, state2: Any) -> str:
raise MesopException("Tried to diff state which was not a dataclass")

custom_actions = []
custom_operators = [UploadedFileOperator()]
custom_operators = [EqualityOperator()]
# Only use the `DataFrameOperator` if pandas exists.
if _has_pandas:
differences = DeepDiff(
Expand All @@ -328,15 +358,15 @@ def diff_state(state1: Any, state2: Any) -> str:
else:
differences = DeepDiff(state1, state2, custom_operators=custom_operators)

# Manually format UploadedFile diffs to flat dict format.
if _DIFF_ACTION_UPLOADED_FILE_CHANGED in differences:
# Manually format diffs to flat dict format.
if _DIFF_ACTION_EQUALITY_CHANGED in differences:
custom_actions = [
{
"path": parse_path(path),
"action": _DIFF_ACTION_UPLOADED_FILE_CHANGED,
"action": _DIFF_ACTION_EQUALITY_CHANGED,
**diff,
}
for path, diff in differences[_DIFF_ACTION_UPLOADED_FILE_CHANGED].items()
for path, diff in differences[_DIFF_ACTION_EQUALITY_CHANGED].items()
]

# Handle the set case which will have a modified path after being JSON encoded.
Expand Down
82 changes: 82 additions & 0 deletions mesop/dataclass_utils/dataclass_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from pydantic import BaseModel

import mesop.protos.ui_pb2 as pb
from mesop.components.uploader.uploaded_file import UploadedFile
Expand Down Expand Up @@ -49,6 +50,35 @@ class WithUploadedFile:
data: UploadedFile = field(default_factory=UploadedFile)


class NestedPydanticModel(BaseModel):
default_value: str = "default"
no_default_value: str


class PydanticModel(BaseModel):
name: str = "World"
counter: int = 0
list_models: list[NestedPydanticModel] = field(default_factory=lambda: [])
nested: NestedPydanticModel = field(
default_factory=lambda: NestedPydanticModel(
no_default_value="<no_default_factory>"
)
)
optional_value: str | None = None
union_value: str | int = 0
tuple_value: tuple[str, int] = ("a", 1)


@dataclass_with_defaults
class WithPydanticModel:
data: PydanticModel


@dataclass_with_defaults
class WithPydanticModelDefaultFactory:
default_factory: PydanticModel = field(default_factory=PydanticModel)


JSON_STR = """{"b": {"c": {"val": "<init>"}},
"list_b": [
{"c": {"val": "1"}},
Expand Down Expand Up @@ -180,6 +210,58 @@ def test_serialize_uploaded_file():
)


def test_serialize_deserialize_pydantic_model():
state = WithPydanticModel()
state.data.name = "Hello"
state.data.counter = 1
state.data.nested = NestedPydanticModel(no_default_value="no_default")
state.data.list_models.append(
NestedPydanticModel(no_default_value="no_default_list_model_val_1")
)
state.data.list_models.append(
NestedPydanticModel(no_default_value="no_default_list_model_val_2")
)
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_set_optional_value():
state = WithPydanticModel()
state.data.optional_value = "optional"
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_set_union_value():
state = WithPydanticModel()
state.data.union_value = "union_value"
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_set_tuple_value():
state = WithPydanticModel()
state.data.tuple_value = ("tuple_value", 1)
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_default_factory():
state = WithPydanticModelDefaultFactory()
state.default_factory.name = "Hello"
state.default_factory.counter = 1
state.default_factory.nested = NestedPydanticModel(
no_default_value="no_default"
)
new_state = WithPydanticModelDefaultFactory()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


@pytest.mark.parametrize(
"input_bytes, expected_json",
[
Expand Down
30 changes: 29 additions & 1 deletion mesop/dataclass_utils/diff_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd
import pytest
from pydantic import BaseModel

from mesop.components.uploader.uploaded_file import UploadedFile
from mesop.dataclass_utils.dataclass_utils import diff_state
Expand Down Expand Up @@ -409,7 +410,7 @@ class C:
assert json.loads(diff_state(s1, s2)) == [
{
"path": ["data"],
"action": "mesop_uploaded_file_changed",
"action": "mesop_equality_changed",
"value": {
"__mesop.UploadedFile__": {
"contents": "ZGF0YQ==",
Expand All @@ -422,6 +423,33 @@ class C:
]


def test_diff_pydantic_model():
class PydanticModel(BaseModel):
name: str = "World"
counter: int = 0

@dataclass
class C:
data: PydanticModel

s1 = C(data=PydanticModel())
s2 = C(data=PydanticModel(name="Hello", counter=1))

assert json.loads(diff_state(s1, s2)) == [
{
"path": ["data"],
"action": "mesop_equality_changed",
"value": {
"__pydantic.BaseModel__": {
"json": '{"name":"Hello","counter":1}',
"module": "dataclass_utils.diff_state_test",
"qualname": "test_diff_pydantic_model.<locals>.PydanticModel",
},
},
}
]


def test_diff_uploaded_file_same_no_diff():
@dataclass
class C:
Expand Down
1 change: 1 addition & 0 deletions mesop/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from mesop.examples import on_load_generator as on_load_generator
from mesop.examples import playground as playground
from mesop.examples import playground_critic as playground_critic
from mesop.examples import pydantic_state as pydantic_state
from mesop.examples import query_params as query_params
from mesop.examples import readme_app as readme_app
from mesop.examples import responsive_layout as responsive_layout
Expand Down
27 changes: 27 additions & 0 deletions mesop/examples/pydantic_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pydantic import BaseModel

import mesop as me


class PydanticModel(BaseModel):
name: str = "World"
counter: int = 0


@me.stateclass
class State:
model: PydanticModel


@me.page(path="/pydantic_state")
def main():
state = me.state(State)
me.text(f"Name: {state.model.name}")
me.text(f"Counter: {state.model.counter}")

me.button("Increment Counter", on_click=on_click)


def on_click(e: me.ClickEvent):
state = me.state(State)
state.model.counter += 1
14 changes: 14 additions & 0 deletions mesop/tests/e2e/pydantic_state_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import {test, expect} from '@playwright/test';

test('pydantic state is serialized and deserialized properly', async ({
page,
}) => {
await page.goto('/pydantic_state');

await expect(page.getByText('Name: world')).toBeVisible();
await expect(page.getByText('Counter: 0')).toBeVisible();
await page.getByRole('button', {name: 'Increment Counter'}).click();
await expect(page.getByText('Counter: 1')).toBeVisible();
await page.getByRole('button', {name: 'Increment Counter'}).click();
await expect(page.getByText('Counter: 2')).toBeVisible();
});
4 changes: 2 additions & 2 deletions mesop/web/src/utils/diff.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export function applyComponentDiff(component: Component, diff: ComponentDiff) {
const STATE_DIFF_VALUES_CHANGED = 'values_changed';
const STATE_DIFF_TYPE_CHANGES = 'type_changes';
const STATE_DIFF_DATA_FRAME_CHANGED = 'data_frame_changed';
const STATE_DIFF_UPLOADED_FILE_CHANGED = 'mesop_uploaded_file_changed';
const STATE_DIFF_EQUALITY_CHANGED = 'mesop_equality_changed';
const STATE_DIFF_ITERABLE_ITEM_REMOVED = 'iterable_item_removed';
const STATE_DIFF_ITERABLE_ITEM_ADDED = 'iterable_item_added';
const STATE_DIFF_SET_ITEM_REMOVED = 'set_item_removed';
Expand Down Expand Up @@ -118,7 +118,7 @@ export function applyStateDiff(stateJson: string, diffJson: string): string {
row.action === STATE_DIFF_VALUES_CHANGED ||
row.action === STATE_DIFF_TYPE_CHANGES ||
row.action === STATE_DIFF_DATA_FRAME_CHANGED ||
row.action === STATE_DIFF_UPLOADED_FILE_CHANGED
row.action === STATE_DIFF_EQUALITY_CHANGED
) {
updateValue(root, row.path, row.value);
} else if (row.action === STATE_DIFF_DICT_ITEM_ADDED) {
Expand Down
2 changes: 1 addition & 1 deletion mesop/web/src/utils/diff_state_spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ describe('applyStateDiff functionality', () => {
const diff = JSON.stringify([
{
path: ['data'],
action: 'mesop_uploaded_file_changed',
action: 'mesop_equality_changed',
value: {
'__mesop.UploadedFile__': {
'contents': 'data',
Expand Down

0 comments on commit 3051e48

Please sign in to comment.