Skip to content

Commit

Permalink
Make database mapping work when fetching external changes (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Jan 30, 2024
2 parents 474d3a7 + 89e44fd commit d187d29
Show file tree
Hide file tree
Showing 14 changed files with 1,826 additions and 425 deletions.
104 changes: 104 additions & 0 deletions spinedb_api/conflict_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# This file is part of Spine Database API.
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
from __future__ import annotations
from enum import auto, Enum, unique
from dataclasses import dataclass

from .item_status import Status


@unique
class Resolution(Enum):
USE_IN_MEMORY = auto()
USE_IN_DB = auto()


@dataclass
class Conflict:
in_memory: MappedItemBase
in_db: MappedItemBase


@dataclass
class Resolved(Conflict):
resolution: Resolution

def __init__(self, conflict, resolution):
self.in_memory = conflict.in_memory
self.in_db = conflict.in_db
self.resolution = resolution


def select_in_memory_item_always(conflicts):
return [Resolved(conflict, Resolution.USE_IN_MEMORY) for conflict in conflicts]


def select_in_db_item_always(conflicts):
return [Resolved(conflict, Resolution.USE_IN_DB) for conflict in conflicts]


@dataclass
class KeepInMemoryAction:
in_memory: MappedItemBase
set_uncommitted: bool

def __init__(self, conflict):
self.in_memory = conflict.in_memory
self.set_uncommitted = not conflict.in_memory.equal_ignoring_ids(conflict.in_db)


@dataclass
class UpdateInMemoryAction:
in_memory: MappedItemBase
in_db: MappedItemBase

def __init__(self, conflict):
self.in_memory = conflict.in_memory
self.in_db = conflict.in_db


@dataclass
class ResurrectAction:
in_memory: MappedItemBase
in_db: MappedItemBase

def __init__(self, conflict):
self.in_memory = conflict.in_memory
self.in_db = conflict.in_db


def resolved_conflict_actions(conflicts):
for conflict in conflicts:
if conflict.resolution == Resolution.USE_IN_MEMORY:
yield KeepInMemoryAction(conflict)
elif conflict.resolution == Resolution.USE_IN_DB:
yield UpdateInMemoryAction(conflict)
else:
raise RuntimeError(f"unknown conflict resolution")


def resurrection_conflicts_from_resolved(conflicts):
resurrection_conflicts = []
for conflict in conflicts:
if conflict.resolution != Resolution.USE_IN_DB or not conflict.in_memory.removed:
continue
resurrection_conflicts.append(conflict)
return resurrection_conflicts


def make_changed_in_memory_items_dirty(conflicts):
for conflict in conflicts:
if conflict.resolution != Resolution.USE_IN_MEMORY:
continue
if conflict.in_memory.removed:
conflict.in_memory.status = Status.to_remove
elif conflict.in_memory.asdict_() != conflict.in_db:
conflict.in_memory.status = Status.to_update
30 changes: 20 additions & 10 deletions spinedb_api/db_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from alembic.config import Config
from alembic.util.exc import CommandError

from .conflict_resolution import select_in_memory_item_always
from .filters.tools import pop_filter_configs, apply_filter_stack, load_filters
from .spine_db_client import get_db_url_from_server
from .mapped_items import item_factory
Expand Down Expand Up @@ -364,13 +365,16 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
return {}
return item.public_item

def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
def get_items(
self, item_type, fetch=True, skip_removed=True, resolve_conflicts=select_in_memory_item_always, **kwargs
):
"""Finds and returns all the items of one type.
Args:
item_type (str): One of <spine_item_types>.
fetch (bool, optional): Whether to fetch the DB before returning the items.
skip_removed (bool, optional): Whether to ignore removed items.
resolve_conflicts (Callable): function that resolves fetch conflicts
**kwargs: Fields and values for one the unique keys as specified for the item type
in :ref:`db_mapping_schema`.
Expand All @@ -381,7 +385,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
mapped_table = self.mapped_table(item_type)
mapped_table.check_fields(kwargs, valid_types=(type(None),))
if fetch:
self.do_fetch_all(item_type, **kwargs)
self.do_fetch_all(item_type, resolve_conflicts=resolve_conflicts, **kwargs)
get_items = mapped_table.valid_values if skip_removed else mapped_table.values
return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())]

Expand Down Expand Up @@ -617,7 +621,7 @@ def purge_items(self, item_type):
"""
return bool(self.remove_items(item_type, Asterisk))

def fetch_more(self, item_type, offset=0, limit=None, **kwargs):
def fetch_more(self, item_type, offset=0, limit=None, resolve_conflicts=select_in_memory_item_always, **kwargs):
"""Fetches items from the DB into the in-memory mapping, incrementally.
Args:
Expand All @@ -631,7 +635,12 @@ def fetch_more(self, item_type, offset=0, limit=None, **kwargs):
list(:class:`PublicItem`): The items fetched.
"""
item_type = self.real_item_type(item_type)
return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit, **kwargs)]
return [
x.public_item
for x in self.do_fetch_more(
item_type, offset=offset, limit=limit, resolve_conflicts=resolve_conflicts, **kwargs
)
]

def fetch_all(self, *item_types):
"""Fetches items from the DB into the in-memory mapping.
Expand Down Expand Up @@ -696,13 +705,18 @@ def commit_session(self, comment):
date = datetime.now(timezone.utc)
ins = self._metadata.tables["commit"].insert()
with self.engine.begin() as connection:
commit_item = {"user": user, "date": date, "comment": comment}
try:
commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0]
commit_id = connection.execute(ins, commit_item).inserted_primary_key[0]
except DBAPIError as e:
raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e
commit_item["id"] = commit_id
commit_table = self.mapped_table("commit")
commit_table.add_item_from_db(commit_item)
commit_item_id = commit_table.id_map.item_id(commit_id)
for tablename, (to_add, to_update, to_remove) in dirty_items:
for item in to_add + to_update + to_remove:
item.commit(commit_id)
item.commit(commit_item_id)
# Remove before add, to help with keeping integrity constraints
self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove})
self._do_update_items(connection, tablename, *to_update)
Expand All @@ -720,10 +734,6 @@ def rollback_session(self):
if self._memory:
self._memory_dirty = False

def refresh_session(self):
"""Resets the fetch status so new items from the DB can be retrieved."""
self._refresh()

def has_external_commits(self):
"""Tests whether the database has had commits from other sources than this mapping.
Expand Down
Loading

0 comments on commit d187d29

Please sign in to comment.