Skip to content

Commit

Permalink
more human readable err msg and test
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Nov 30, 2024
1 parent 47b26fe commit 6126945
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 10 deletions.
11 changes: 8 additions & 3 deletions src/pymatgen/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,21 @@ def assert_msonable(obj: Any, test_is_subclass: bool = True) -> str:
Returns:
str: Serialized object.
"""
obj_name = obj.__class__.__name__

# Check if is an instance of MONable (or its subclasses)
if test_is_subclass and not isinstance(obj, MSONable):
raise TypeError("obj is not MSONable")
raise TypeError(f"{obj_name} object is not MSONable")

# Check if the object can be accurately reconstructed from its dict representation
if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict():
raise ValueError("obj could not be reconstructed accurately from its dict representation.")
raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.")

# Verify that the deserialized object's class is a subclass of the original object's class
json_str = json.dumps(obj.as_dict(), cls=MontyEncoder)
round_trip = json.loads(json_str, cls=MontyDecoder)
if not issubclass(type(round_trip), type(obj)):
raise TypeError(f"{type(round_trip)} != {type(obj)}")
raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}")
return json_str

@staticmethod
Expand Down
46 changes: 39 additions & 7 deletions tests/util/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import json
import os
from pathlib import Path
from unittest.mock import patch

import pytest
from monty.json import MontyDecoder

from pymatgen.core import Element, Structure
from pymatgen.io.vasp.inputs import Kpoints
Expand Down Expand Up @@ -54,13 +56,6 @@ def test_creating_files_in_tmp_dir(self):


class TestAssertMSONable(PymatgenTest):
"""TODO:
- test: raise TypeError("obj is not MSONable")
- test: raise ValueError("obj could not be reconstructed accurately from its dict representation.")
- test: if not issubclass(type(round_trip), type(obj)):
raise TypeError(f"{type(round_trip)} != {type(obj)}")
"""

def test_valid_msonable(self):
"""Test a valid MSONable object."""
kpts_obj = Kpoints.monkhorst_automatic((2, 2, 2), [0, 0, 0])
Expand All @@ -86,6 +81,43 @@ def test_valid_msonable(self):

assert is_np_dict_equal(serialized, expected_result)

def test_non_msonable(self):
non_msonable = dict(hello="world")
# Test `test_is_subclass` is True
with pytest.raises(TypeError, match="dict object is not MSONable"):
self.assert_msonable(non_msonable)

# Test `test_is_subclass` is False (dict don't have `as_dict` method)
with pytest.raises(AttributeError, match="'dict' object has no attribute 'as_dict'"):
self.assert_msonable(non_msonable, test_is_subclass=False)

def test_cannot_reconstruct(self):
"""Patch the `from_dict` method of `Kpoints` to return a corrupted object"""
kpts_obj = Kpoints.monkhorst_automatic((2, 2, 2), [0, 0, 0])

with patch.object(Kpoints, "from_dict", side_effect=lambda d: Kpoints(comment="Corrupted Object")):
reconstructed_obj = Kpoints.from_dict(kpts_obj.as_dict())
assert reconstructed_obj.comment == "Corrupted Object"

with pytest.raises(ValueError, match="Kpoints object could not be reconstructed accurately"):
self.assert_msonable(kpts_obj)

def test_not_round_trip(self):
kpts_obj = Kpoints.monkhorst_automatic((2, 2, 2), [0, 0, 0])

# Patch the MontyDecoder to return an object of a different class
class NotAKpoints:
pass

with patch.object(MontyDecoder, "process_decoded", side_effect=lambda d: NotAKpoints()) as mock_decoder:
with pytest.raises(
TypeError,
match="The reconstructed NotAKpoints object is not a subclass of Kpoints",
):
self.assert_msonable(kpts_obj)

mock_decoder.assert_called()


class TestPymatgenTest(PymatgenTest):
def test_assert_str_content_equal(self):
Expand Down

0 comments on commit 6126945

Please sign in to comment.