diff --git a/src/pymatgen/util/testing.py b/src/pymatgen/util/testing.py index 31159571b3f..4e9bb8bbccf 100644 --- a/src/pymatgen/util/testing.py +++ b/src/pymatgen/util/testing.py @@ -79,16 +79,21 @@ def assert_msonable(obj: Any, test_is_subclass: bool = True) -> str: Returns: str: Serialized object. """ + obj_name = obj.__class__.__name__ + + # Check if is an instance of MONable (or its subclasses) if test_is_subclass and not isinstance(obj, MSONable): - raise TypeError("obj is not MSONable") + raise TypeError(f"{obj_name} object is not MSONable") + # Check if the object can be accurately reconstructed from its dict representation if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict(): - raise ValueError("obj could not be reconstructed accurately from its dict representation.") + raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.") + # Verify that the deserialized object's class is a subclass of the original object's class json_str = json.dumps(obj.as_dict(), cls=MontyEncoder) round_trip = json.loads(json_str, cls=MontyDecoder) if not issubclass(type(round_trip), type(obj)): - raise TypeError(f"{type(round_trip)} != {type(obj)}") + raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}") return json_str @staticmethod diff --git a/tests/util/test_testing.py b/tests/util/test_testing.py index 1402f9db9c5..a8b845bc62f 100644 --- a/tests/util/test_testing.py +++ b/tests/util/test_testing.py @@ -3,8 +3,10 @@ import json import os from pathlib import Path +from unittest.mock import patch import pytest +from monty.json import MontyDecoder from pymatgen.core import Element, Structure from pymatgen.io.vasp.inputs import Kpoints @@ -54,13 +56,6 @@ def test_creating_files_in_tmp_dir(self): class TestAssertMSONable(PymatgenTest): - """TODO: - - test: raise TypeError("obj is not MSONable") - - test: raise ValueError("obj could not be reconstructed accurately from its dict representation.") - - test: if not issubclass(type(round_trip), type(obj)): - raise TypeError(f"{type(round_trip)} != {type(obj)}") - """ - def test_valid_msonable(self): """Test a valid MSONable object.""" kpts_obj = Kpoints.monkhorst_automatic((2, 2, 2), [0, 0, 0]) @@ -86,6 +81,43 @@ def test_valid_msonable(self): assert is_np_dict_equal(serialized, expected_result) + def test_non_msonable(self): + non_msonable = dict(hello="world") + # Test `test_is_subclass` is True + with pytest.raises(TypeError, match="dict object is not MSONable"): + self.assert_msonable(non_msonable) + + # Test `test_is_subclass` is False (dict don't have `as_dict` method) + with pytest.raises(AttributeError, match="'dict' object has no attribute 'as_dict'"): + self.assert_msonable(non_msonable, test_is_subclass=False) + + def test_cannot_reconstruct(self): + """Patch the `from_dict` method of `Kpoints` to return a corrupted object""" + kpts_obj = Kpoints.monkhorst_automatic((2, 2, 2), [0, 0, 0]) + + with patch.object(Kpoints, "from_dict", side_effect=lambda d: Kpoints(comment="Corrupted Object")): + reconstructed_obj = Kpoints.from_dict(kpts_obj.as_dict()) + assert reconstructed_obj.comment == "Corrupted Object" + + with pytest.raises(ValueError, match="Kpoints object could not be reconstructed accurately"): + self.assert_msonable(kpts_obj) + + def test_not_round_trip(self): + kpts_obj = Kpoints.monkhorst_automatic((2, 2, 2), [0, 0, 0]) + + # Patch the MontyDecoder to return an object of a different class + class NotAKpoints: + pass + + with patch.object(MontyDecoder, "process_decoded", side_effect=lambda d: NotAKpoints()) as mock_decoder: + with pytest.raises( + TypeError, + match="The reconstructed NotAKpoints object is not a subclass of Kpoints", + ): + self.assert_msonable(kpts_obj) + + mock_decoder.assert_called() + class TestPymatgenTest(PymatgenTest): def test_assert_str_content_equal(self):