Skip to content

Commit

Permalink
Merge pull request #44 from lidatong/fix-decoder-hook-and-make-iso-field
Browse files Browse the repository at this point in the history
Fix decoder hook and add mm custom iso field
  • Loading branch information
lidatong authored Nov 18, 2018
2 parents 149758f + fbcfee1 commit 10f0dd3
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 18 deletions.
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,18 @@ class DataClassWithIsoDatetime:
}})
```

As you can see, you can override:
- `encoder`: a callable, which will be invoked to convert the field value when calling `to_json`
- `decoder`: a callable, which will be invoked to convert the JSON value when calling `from_json`
As you can see, you can override the default codecs by providing a "hook" via a
callable:
- `encoder`: a callable, which will be invoked to convert the field value when encoding to JSON
- `decoder`: a callable, which will be invoked to convert the JSON value when decoding from JSON
- `mm_field`: a marshmallow field, which will affect the behavior of any operations involving `.schema()`

Note that these hooks will be invoked regardless if you're using
`.to_json`/`dump`/`dumps
and `.from_json`/`load`/`loads`. So apply overrides judiciously, making sure to
carefully consider whether the interaction of the encode/decode/mm_field
overrides is consistent with what you expect!

## A larger example

```python
Expand Down
11 changes: 5 additions & 6 deletions dataclasses_json/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dataclasses_json import mm
from dataclasses_json.core import (_ExtendedEncoder, _asdict, _decode_dataclass,
_field_overrides, _issubclass_safe,
_overrides, _issubclass_safe,
_override)

A = TypeVar('A')
Expand All @@ -35,7 +35,7 @@ def to_json(self,
default: Callable = None,
sort_keys: bool = False,
**kw) -> str:
kvs = _override(_asdict(self), _field_overrides(self), 'encoder')
kvs = _override(_asdict(self), _overrides(self), 'encoder')
return json.dumps(kvs,
cls=_ExtendedEncoder,
skipkeys=skipkeys,
Expand Down Expand Up @@ -64,8 +64,7 @@ def from_json(cls: A,
parse_int=parse_int,
parse_constant=parse_constant,
**kw)
overriden_kvs = _override(kvs, _field_overrides(cls), 'decoder')
return _decode_dataclass(cls, overriden_kvs, infer_missing)
return _decode_dataclass(cls, kvs, infer_missing)

@classmethod
def schema(cls,
Expand All @@ -85,10 +84,10 @@ def schema(cls,

@post_load
def make_instance(self, kvs):
return _decode_dataclass(cls, kvs, infer_missing=partial)
return _decode_dataclass(cls, kvs, partial)

overriden_fields = {k: v.mm_field
for k, v in _field_overrides(cls).items()}
for k, v in _overrides(cls).items()}
generated_fields = {field for field in fields(cls)
if field.name not in overriden_fields}
nested_fields = mm._make_nested_fields(generated_fields,
Expand Down
17 changes: 13 additions & 4 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def default(self, o) -> JSON:
return result


def _field_overrides(dc):
def _overrides(dc):
overrides = {}
attrs = ['encoder', 'decoder', 'mm_field']
FieldOverride = namedtuple('FieldOverride', attrs)
Expand All @@ -56,6 +56,7 @@ def _override(kvs, overrides, attr):


def _decode_dataclass(cls, kvs, infer_missing):
overrides = _overrides(cls)
kvs = {} if kvs is None and infer_missing else kvs
missing_fields = {field for field in fields(cls) if field.name not in kvs}
for field in missing_fields:
Expand All @@ -81,6 +82,14 @@ def _decode_dataclass(cls, kvs, infer_missing):
else:
warnings.warn(f"`NoneType` object {warning}.", RuntimeWarning)
init_kwargs[field.name] = field_value
elif (field.name in overrides
and overrides[field.name].decoder is not None):
# FIXME hack
if field.type is type(field_value):
init_kwargs[field.name] = field_value
else:
init_kwargs[field.name] = overrides[field.name].decoder(
field_value)
elif is_dataclass(field.type):
# FIXME this is a band-aid to deal with the value already being
# serialized when handling nested marshmallow schema
Expand All @@ -89,8 +98,7 @@ def _decode_dataclass(cls, kvs, infer_missing):
if is_dataclass(field_value):
value = field_value
else:
value = _decode_dataclass(field.type,
field_value,
value = _decode_dataclass(field.type, field_value,
infer_missing)
init_kwargs[field.name] = value

Expand Down Expand Up @@ -174,7 +182,8 @@ def _decode_items(type_arg, xs, infer_missing):
hence the check of `is_dataclass(vs)`
"""
if is_dataclass(type_arg) or is_dataclass(xs):
items = (_decode_dataclass(type_arg, x, infer_missing) for x in xs)
items = (_decode_dataclass(type_arg, x, infer_missing)
for x in xs)
elif _is_supported_generic(type_arg):
items = (_decode_generic(type_arg, x, infer_missing) for x in xs)
else:
Expand Down
8 changes: 8 additions & 0 deletions dataclasses_json/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def _deserialize(self, value, attr, data, **kwargs):
return _timestamp_to_dt_aware(value)


class _IsoField(fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
return value.isoformat()

def _deserialize(self, value, attr, data, **kwargs):
return datetime.fromisoformat(value)


_type_to_cons = {
dict: fields.Dict,
list: fields.List,
Expand Down
12 changes: 12 additions & 0 deletions tests/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from marshmallow import fields

from dataclasses_json import DataClassJsonMixin, dataclass_json
from dataclasses_json.mm import _IsoField

A = TypeVar('A')

Expand Down Expand Up @@ -172,6 +173,17 @@ class DataClassWithIsoDatetime:
'mm_field': fields.DateTime(format='iso')
}})

@dataclass_json
@dataclass
class DataClassWithCustomIsoDatetime:
created_at: datetime = field(
metadata={'dataclasses_json': {
'encoder': datetime.isoformat,
'decoder': datetime.fromisoformat,
'mm_field': _IsoField()
}})



@dataclass_json
@dataclass
Expand Down
15 changes: 12 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
DataClassWithDataClass, DataClassWithDatetime,
DataClassWithList, DataClassWithOptional,
DataClassWithOptionalNested, DataClassWithUuid,
DataClassWithIsoDatetime, DataClassWithOverride)
DataClassWithIsoDatetime, DataClassWithOverride,
DataClassWithCustomIsoDatetime)


class TestTypes:
Expand Down Expand Up @@ -60,13 +61,21 @@ def test_datetime_override_schema_encode(self):

def test_datetime_override_schema_decode(self):
iso = DataClassWithIsoDatetime.schema().loads(self.dc_iso_json)
# FIXME bug in marshmallow currently returns naive instead of the
# document aware. also seems to drop microseconds?
# FIXME bug in marshmallow currently returns datetime-naive instead of
# datetime-aware. also seems to drop microseconds?
# #955
iso.created_at = iso.created_at.replace(microsecond=456753,
tzinfo=self.tz)
assert (iso == self.dc_iso)

def test_datetime_custom_iso_fieldoverride_schema_encode(self):
assert (DataClassWithCustomIsoDatetime.schema().dumps(self.dc_iso)
== self.dc_iso_json)

def test_datetime_custom_iso_field_override_schema_decode(self):
iso = DataClassWithCustomIsoDatetime.schema().loads(self.dc_iso_json)
assert (iso == DataClassWithCustomIsoDatetime(self.dt))


class TestInferMissing:
def test_infer_missing(self):
Expand Down
12 changes: 10 additions & 2 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@

@dataclass
class Car(DataClassJsonMixin):
license_number: str = field(metadata={'mm': fields.String(required=False)})
license_number: str = field(
metadata={'dataclasses_json': {
'mm_field': fields.String(required=False)}
})


@dataclass
class StringDate(DataClassJsonMixin):
string_date: datetime.datetime = field(metadata={'mm': fields.String(required=False)})
string_date: datetime.datetime = field(
metadata={'dataclasses_json': {
'encoder': str,
'decoder': str,
'mm_field': fields.String(required=False)}
})


car_schema = Car.schema()
Expand Down

0 comments on commit 10f0dd3

Please sign in to comment.