Skip to content

Commit

Permalink
Fix alternative filter
Browse files Browse the repository at this point in the history
Alternative filter was not dropping parameter values of
incomplete multidimensional entities.
  • Loading branch information
soininen committed Nov 22, 2024
1 parent 1564a61 commit b98016f
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 134 deletions.
67 changes: 16 additions & 51 deletions spinedb_api/filters/alternative_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
######################################################################################################################
""" Provides functions to apply filtering based on alternatives to parameter value subqueries. """
from functools import partial
from sqlalchemy import and_, func, or_
from sqlalchemy import and_, or_
from ..exception import SpineDBAPIError
from .query_utils import filter_by_active_elements

ALTERNATIVE_FILTER_TYPE = "alternative_filter"
ALTERNATIVE_FILTER_SHORTHAND_TAG = "alternatives"
Expand Down Expand Up @@ -274,55 +275,19 @@ def _make_alternative_filtered_entity_sq(db_map, state):
Alias: a subquery for entity filtered by selected alternatives
"""
ext_entity_sq = _ext_entity_sq(db_map, state)
ext_entity_element_count_sq = (
db_map.query(
db_map.entity_element_sq.c.entity_id,
func.count(db_map.entity_element_sq.c.element_id).label("element_count"),
)
.group_by(db_map.entity_element_sq.c.entity_id)
.subquery()
)
ext_entity_class_dimension_count_sq = (
db_map.query(
db_map.entity_class_dimension_sq.c.entity_class_id,
func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"),
)
.group_by(db_map.entity_class_dimension_sq.c.entity_class_id)
.subquery()
)
return (
db_map.query(
ext_entity_sq.c.id,
ext_entity_sq.c.class_id,
ext_entity_sq.c.name,
ext_entity_sq.c.description,
ext_entity_sq.c.commit_id,
)
.filter(
or_(
ext_entity_sq.c.active == True,
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
),
)
.outerjoin(
ext_entity_element_count_sq,
ext_entity_element_count_sq.c.entity_id == ext_entity_sq.c.id,
)
.outerjoin(
ext_entity_class_dimension_count_sq,
ext_entity_class_dimension_count_sq.c.entity_class_id == ext_entity_sq.c.class_id,
)
.filter(
or_(
and_(
ext_entity_element_count_sq.c.element_count == None,
ext_entity_class_dimension_count_sq.c.dimension_count == None,
),
ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count,
)
)
.subquery()
filtered_by_activity = db_map.query(
ext_entity_sq.c.id,
ext_entity_sq.c.class_id,
ext_entity_sq.c.name,
ext_entity_sq.c.description,
ext_entity_sq.c.commit_id,
).filter(
or_(
ext_entity_sq.c.active == True,
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
),
)
return filter_by_active_elements(db_map, filtered_by_activity, ext_entity_sq).subquery()


def _make_alternative_filtered_alternative_sq(db_map, state):
Expand Down Expand Up @@ -395,7 +360,7 @@ def _make_alternative_filtered_parameter_value_sq(db_map, state):
"""
subquery = state.original_parameter_value_sq
ext_entity_sq = _ext_entity_sq(db_map, state)
return (
filtered_by_activity = (
db_map.query(subquery)
.filter(subquery.c.alternative_id.in_(state.alternatives))
.filter(subquery.c.entity_id == ext_entity_sq.c.id)
Expand All @@ -405,5 +370,5 @@ def _make_alternative_filtered_parameter_value_sq(db_map, state):
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
)
)
.subquery()
)
return filter_by_active_elements(db_map, filtered_by_activity, ext_entity_sq).subquery()
64 changes: 64 additions & 0 deletions spinedb_api/filters/query_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# Copyright Spine Database API contributors
# This file is part of Spine Database API.
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
""" Provides utilities for database queries. """
from sqlalchemy import and_, func, or_


def filter_by_active_elements(db_map, query, ext_entity_sq):
"""Applies a filter to given subquery that drops incomplete multidimensional entities.
'Incomplete' means entities that have elements that are inactive,
i.e. are filtered out because entity alternative/active_by_default is set to False.
Args:
db_map (DatabaseMapping): database map
query (Query): query to apply the filter to
ext_entity_sq (Alias): extended entity subquery
Returns:
Alias: filtered subquery
"""
ext_entity_element_count_sq = (
db_map.query(
db_map.entity_element_sq.c.entity_id,
func.count(db_map.entity_element_sq.c.element_id).label("element_count"),
)
.group_by(db_map.entity_element_sq.c.entity_id)
.subquery()
)
ext_entity_class_dimension_count_sq = (
db_map.query(
db_map.entity_class_dimension_sq.c.entity_class_id,
func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"),
)
.group_by(db_map.entity_class_dimension_sq.c.entity_class_id)
.subquery()
)
return (
query.outerjoin(
ext_entity_element_count_sq,
ext_entity_element_count_sq.c.entity_id == ext_entity_sq.c.id,
)
.outerjoin(
ext_entity_class_dimension_count_sq,
ext_entity_class_dimension_count_sq.c.entity_class_id == ext_entity_sq.c.class_id,
)
.filter(
or_(
and_(
ext_entity_element_count_sq.c.element_count == None,
ext_entity_class_dimension_count_sq.c.dimension_count == None,
),
ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count,
)
)
)
99 changes: 16 additions & 83 deletions spinedb_api/filters/scenario_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from functools import partial
from sqlalchemy import and_, desc, func, or_
from ..exception import SpineDBAPIError
from .query_utils import filter_by_active_elements

SCENARIO_FILTER_TYPE = "scenario_filter"
SCENARIO_SHORTHAND_TAG = "scenario"
Expand Down Expand Up @@ -269,56 +270,20 @@ def _make_scenario_filtered_entity_sq(db_map, state):
Alias: a subquery for entity filtered by selected scenario
"""
ext_entity_sq = _ext_entity_sq(db_map, state)
ext_entity_element_count_sq = (
db_map.query(
db_map.entity_element_sq.c.entity_id,
func.count(db_map.entity_element_sq.c.element_id).label("element_count"),
)
.group_by(db_map.entity_element_sq.c.entity_id)
.subquery()
)
ext_entity_class_dimension_count_sq = (
db_map.query(
db_map.entity_class_dimension_sq.c.entity_class_id,
func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"),
)
.group_by(db_map.entity_class_dimension_sq.c.entity_class_id)
.subquery()
)
return (
db_map.query(
ext_entity_sq.c.id,
ext_entity_sq.c.class_id,
ext_entity_sq.c.name,
ext_entity_sq.c.description,
ext_entity_sq.c.commit_id,
)
.filter(
ext_entity_sq.c.desc_rank_row_number == 1,
or_(
ext_entity_sq.c.active == True,
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
),
)
.outerjoin(
ext_entity_element_count_sq,
ext_entity_element_count_sq.c.entity_id == ext_entity_sq.c.id,
)
.outerjoin(
ext_entity_class_dimension_count_sq,
ext_entity_class_dimension_count_sq.c.entity_class_id == ext_entity_sq.c.class_id,
)
.filter(
or_(
and_(
ext_entity_element_count_sq.c.element_count == None,
ext_entity_class_dimension_count_sq.c.dimension_count == None,
),
ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count,
)
)
.subquery()
filtered_by_activity = db_map.query(
ext_entity_sq.c.id,
ext_entity_sq.c.class_id,
ext_entity_sq.c.name,
ext_entity_sq.c.description,
ext_entity_sq.c.commit_id,
).filter(
ext_entity_sq.c.desc_rank_row_number == 1,
or_(
ext_entity_sq.c.active == True,
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
),
)
return filter_by_active_elements(db_map, filtered_by_activity, ext_entity_sq).subquery()


def _make_scenario_filtered_entity_alternative_sq(db_map, state):
Expand Down Expand Up @@ -360,22 +325,6 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state):
Alias: a subquery for parameter value filtered by selected scenario
"""
ext_entity_sq = _ext_entity_sq(db_map, state)
ext_entity_element_count_sq = (
db_map.query(
db_map.entity_element_sq.c.entity_id,
func.count(db_map.entity_element_sq.c.element_id).label("element_count"),
)
.group_by(db_map.entity_element_sq.c.entity_id)
.subquery()
)
ext_entity_class_dimension_count_sq = (
db_map.query(
db_map.entity_class_dimension_sq.c.entity_class_id,
func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"),
)
.group_by(db_map.entity_class_dimension_sq.c.entity_class_id)
.subquery()
)
ext_parameter_value_sq = (
db_map.query(
state.original_parameter_value_sq,
Expand All @@ -392,7 +341,7 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state):
.filter(state.original_parameter_value_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id)
.filter(db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id)
).subquery()
return (
filtered_by_entity_activity = (
db_map.query(ext_parameter_value_sq)
.filter(ext_parameter_value_sq.c.desc_rank_row_number == 1)
.filter(ext_parameter_value_sq.c.entity_id == ext_entity_sq.c.id)
Expand All @@ -403,24 +352,8 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state):
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
),
)
.outerjoin(
ext_entity_element_count_sq, ext_entity_element_count_sq.c.entity_id == ext_parameter_value_sq.c.entity_id
)
.outerjoin(
ext_entity_class_dimension_count_sq,
ext_entity_class_dimension_count_sq.c.entity_class_id == ext_parameter_value_sq.c.entity_class_id,
)
.filter(
or_(
and_(
ext_entity_element_count_sq.c.element_count == None,
ext_entity_class_dimension_count_sq.c.dimension_count == None,
),
ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count,
)
)
.subquery()
)
return filter_by_active_elements(db_map, filtered_by_entity_activity, ext_entity_sq).subquery()


def _make_scenario_filtered_alternative_sq(db_map, state):
Expand Down
52 changes: 52 additions & 0 deletions tests/filters/test_alternative_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,58 @@ def test_filters_parameters_values_by_active_by_default(self):
self.assertEqual(len(values), 1)
self.assertEqual(from_database(values[0]["value"], values[0]["type"]), -2.3)

def test_filters_parameter_values_of_multidimensional_entities_inactive_by_elements(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
self._assert_success(db_map.add_entity_class_item(name="Object"))
self._assert_success(db_map.add_entity_item(name="invisible", entity_class_name="Object"))
self._assert_success(
db_map.add_entity_alternative_item(
entity_class_name="Object", entity_byname=("invisible",), alternative_name="Base", active=False
)
)
self._assert_success(db_map.add_entity_item(name="visible", entity_class_name="Object"))
self._assert_success(
db_map.add_entity_alternative_item(
entity_class_name="Object", entity_byname=("visible",), alternative_name="Base", active=True
)
)
self._assert_success(db_map.add_entity_class_item(name="Relationship", dimension_name_list=("Object",)))
self._assert_success(
db_map.add_entity_item(element_name_list=("invisible",), entity_class_name="Relationship")
)
self._assert_success(
db_map.add_entity_item(element_name_list=("visible",), entity_class_name="Relationship")
)
self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="Relationship"))
value, value_type = to_database(2.3)
self._assert_success(
db_map.add_parameter_value_item(
entity_class_name="Relationship",
entity_byname=("invisible",),
parameter_definition_name="y",
alternative_name="Base",
value=value,
type=value_type,
)
)
value, value_type = to_database(-2.3)
self._assert_success(
db_map.add_parameter_value_item(
entity_class_name="Relationship",
entity_byname=("visible",),
parameter_definition_name="y",
alternative_name="Base",
value=value,
type=value_type,
)
)
db_map.commit_session("Add values.")
config = alternative_filter_config(["Base"])
alternative_filter_from_dict(db_map, config)
values = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(values), 1)
self.assertEqual(from_database(values[0]["value"], values[0]["type"]), -2.3)

def _build_data_without_alternatives(self, db_map, commit=True):
self._assert_imports(import_entity_classes(db_map, ["object_class"]))
self._assert_imports(import_entities(db_map, [("object_class", "object")]))
Expand Down

0 comments on commit b98016f

Please sign in to comment.