Skip to content

Commit

Permalink
Fix adding parameter definition when db has parameter types
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen committed Sep 13, 2024
1 parent e022a81 commit cd7367d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
21 changes: 16 additions & 5 deletions spinedb_api/db_mapping_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from difflib import SequenceMatcher
from enum import Enum, auto, unique
from multiprocessing import RLock
from typing import Set
from .exception import SpineDBAPIError
from .helpers import Asterisk
from .temp_id import TempId, resolve
Expand Down Expand Up @@ -634,12 +635,13 @@ def _same_item(self, mapped_item, db_item):

def check_fields(self, item, valid_types=()):
factory = self._db_map.item_factory(self._item_type)
field_union = factory.internal_external_private_fields() | {
"id",
"commit_id",
}

def _error(key, value, valid_types):
if key in set(factory._internal_fields) | set(factory._external_fields) | factory._private_fields | {
"id",
"commit_id",
}:
if key in field_union:
# The user seems to know what they're doing
return
f_dict = factory.fields.get(key)
Expand Down Expand Up @@ -780,10 +782,19 @@ def ref_types(cls):
"""Returns a set of item types that this class refers.
Returns:
set(str)
set of str
"""
return set(cls._references.values())

@classmethod
def internal_external_private_fields(cls) -> Set[str]:
"""Returns a union of internal, external and private fields.
Returns:
set of str: field union
"""
return set(cls._internal_fields) | set(cls._external_fields) | cls._private_fields

@property
def db_map(self) -> DatabaseMappingBase:
"""Returns the database mapping of the item."""
Expand Down
12 changes: 11 additions & 1 deletion spinedb_api/mapped_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,21 @@ def __getitem__(self, key):

def _sorted_parameter_types(self):
self._db_map.do_fetch_all("parameter_type")
if "id" in self:
return sorted(
(
x
for x in self._db_map.mapped_table("parameter_type").valid_values()
if x["parameter_definition_id"] == dict.__getitem__(self, "id")
),
key=lambda i: (i["type"], i["rank"]),
)
return sorted(
(
x
for x in self._db_map.mapped_table("parameter_type").valid_values()
if x["parameter_definition_id"] == self["id"]
if x["parameter_definition_name"] == dict.__getitem__(self, "name")
and x["entity_class_name"] == dict.__getitem__(self, "entity_class_name")
),
key=lambda i: (i["type"], i["rank"]),
)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_DatabaseMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,6 +1858,20 @@ def test_do_fetch_more_in_chunks(self):
gadgets = db_map.get_items("entity", entity_class_name="Gadget")
self.assertEqual(len(gadgets), 1)

def test_add_parameter_definition_to_database_with_parameter_types_does_not_raise_key_error(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
self._assert_success(db_map.add_entity_class_item(name="Key"))
repeat_rate = self._assert_success(
db_map.add_parameter_definition_item(
name="repeat rate", entity_class_name="Key", parameter_type_list=("float",)
)
)
self.assertEqual(repeat_rate["parameter_type_list"], ("float",))
is_useful = self._assert_success(
db_map.add_parameter_definition_item(name="is useful", entity_class_name="Key")
)
self.assertNotEqual(is_useful, {})


class TestDatabaseMappingLegacy(unittest.TestCase):
"""'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure."""
Expand Down

0 comments on commit cd7367d

Please sign in to comment.