Skip to content

Commit

Permalink
Fix importing pivoted data with read_start_row (#457)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Oct 17, 2024
2 parents 5ad4bd3 + 9d1c763 commit 6eb6e26
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 23 deletions.
13 changes: 13 additions & 0 deletions spinedb_api/export_mapping/export_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,12 +1309,25 @@ class ScenarioAlternativeMapping(ExportMapping):
MAP_TYPE = "ScenarioAlternative"

def add_query_columns(self, db_map, query):
    """Add scenario-alternative columns to the export query.

    Args:
        db_map: database mapping providing the subquery attributes
        query: query to extend

    Returns:
        the query with alternative id/name (and rank, in the non-legacy case) columns added
    """
    if self._child is not None:
        # Legacy: expecting child to be ScenarioBeforeAlternativeMapping
        linked_sq = db_map.ext_linked_scenario_alternative_sq
        return query.add_columns(
            linked_sq.c.alternative_id,
            linked_sq.c.alternative_name,
        )
    scenario_sq = db_map.ext_scenario_sq
    return query.add_columns(
        scenario_sq.c.alternative_id,
        scenario_sq.c.alternative_name,
        scenario_sq.c.rank,
    )

def filter_query(self, db_map, query):
if self._child is None:
return query.outerjoin(
db_map.ext_scenario_sq,
db_map.ext_scenario_sq.c.id == db_map.scenario_sq.c.id,
).order_by(db_map.ext_scenario_sq.c.name, db_map.ext_scenario_sq.c.rank)
# Legacy: expecting child to be ScenarioBeforeAlternativeMapping
return query.outerjoin(
db_map.ext_linked_scenario_alternative_sq,
db_map.ext_linked_scenario_alternative_sq.c.scenario_id == db_map.scenario_sq.c.id,
Expand Down
5 changes: 1 addition & 4 deletions spinedb_api/export_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Contains generator functions that convert a Spine database into rows of tabular data.
"""
""" Contains generator functions that convert a Spine database into rows of tabular data. """
from copy import deepcopy
from ..mapping import Position
from .export_mapping import pair_header_buddies
Expand Down
9 changes: 2 additions & 7 deletions spinedb_api/export_mapping/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,24 +333,19 @@ def scenario_export(
return scenario_mapping


def scenario_alternative_export(
scenario_position=Position.hidden, alternative_position=Position.hidden, before_alternative_position=Position.hidden
):
def scenario_alternative_export(scenario_position=Position.hidden, alternative_position=Position.hidden):
"""
Sets up export mappings for exporting scenario alternatives.
Args:
scenario_position (int or Position): position of scenarios
alternative_position (int or Position): position of alternatives
before_alternative_position (int or Position): position of 'before' alternatives
(for each row, the 'alternative' goes *before* the 'before alternative' in the scenario rank)
Returns:
Mapping: root mapping
"""
scenario_mapping = ScenarioMapping(scenario_position)
alternative_mapping = scenario_mapping.child = ScenarioAlternativeMapping(alternative_position)
alternative_mapping.child = ScenarioBeforeAlternativeMapping(before_alternative_position)
scenario_mapping.child = ScenarioAlternativeMapping(alternative_position)
return scenario_mapping


Expand Down
35 changes: 26 additions & 9 deletions spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
"""
Contains `get_mapped_data()` that converts rows of tabular data into a dictionary for import to a Spine DB,
using ``import_functions.import_data()``
"""

from copy import deepcopy
from operator import itemgetter
from ..exception import ParameterValueFormatError
from ..helpers import string_to_bool
from ..mapping import Position
from ..mapping import Position, is_pivoted
from ..parameter_value import (
Array,
Map,
Expand Down Expand Up @@ -97,36 +96,47 @@ def get_mapped_data(
mapping_names = []
_ensure_mapping_name_consistency(mappings, mapping_names)
for mapping, mapping_name in zip(mappings, mapping_names):
read_state = {}
mapping = deepcopy(mapping)
mapping.polish(table_name, data_header, mapping_name, column_count)
mapping_errors = check_validity(mapping)
if mapping_errors:
errors += mapping_errors
continue
read_state = {}
# Find pivoted and unpivoted mappings
pivoted, non_pivoted, pivoted_from_header, last = _split_mapping(mapping)
# If there are no pivoted mappings, we can just feed the rows to our mapping directly
if not (pivoted or pivoted_from_header):
start_pos = mapping.read_start_row
for k, row in enumerate(rows[mapping.read_start_row :]):
for k, row in enumerate(rows[start_pos:]):
if not _is_valid_row(row):
continue
row = _convert_row(row, column_convert_fns, start_pos + k, errors)
mapping.import_row(row, read_state, mapped_data)
continue
# There are pivoted mappings. We unpivot the table
pivoted_by_leaf = all(
not is_pivoted(m.position) and m.position != Position.header for m in mapping.flatten()[:-1]
)
unpivoted_rows, pivoted_pos, non_pivoted_pos, unpivoted_column_pos = _unpivot_rows(
rows, data_header, pivoted, non_pivoted, pivoted_from_header, mapping.skip_columns
rows,
data_header,
pivoted,
non_pivoted,
pivoted_from_header,
mapping.skip_columns,
mapping.read_start_row,
pivoted_by_leaf,
)
if not unpivoted_column_pos:
continue
if not is_pivoted(last.position):
last.position = -1
# Reposition row convert functions
row_convert_fns = {k: row_convert_fns[pos] for k, pos in enumerate(pivoted_pos) if pos in row_convert_fns}
# If there are only pivoted mappings, we can just feed the unpivoted rows
if not non_pivoted:
# Reposition pivoted mappings:
last.position = -1
for k, m in enumerate(pivoted):
m.position = k
for k, row in enumerate(unpivoted_rows):
Expand All @@ -140,7 +150,6 @@ def get_mapped_data(
# - The last mapping (typically, parameter value) will read from the last position in the row
# - The pivoted mappings will read from positions to the left of that
k = None
last.position = -1
for k, m in enumerate(reversed(pivoted)):
m.position = -(k + 2)
# Feed rows: To each regular row, we append each unpivoted row, plus the item at the intersection,
Expand Down Expand Up @@ -226,7 +235,9 @@ def _split_mapping(mapping):
return pivoted, non_pivoted, pivoted_from_header, flattened[-1]


def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header, skip_columns):
def _unpivot_rows(
rows, data_header, pivoted, non_pivoted, pivoted_from_header, skip_columns, read_start_row, pivoted_by_leaf
):
"""Unpivots rows.
Args:
Expand All @@ -235,6 +246,9 @@ def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header,
pivoted (list of ImportMapping): Pivoted mappings (reading from rows)
non_pivoted (list of ImportMapping): Non-pivoted mappings ('regular', reading from columns)
pivoted_from_header (list of ImportMapping): Mappings pivoted from header
skip_columns (list of int): columns that should be skipped
read_start_row (int): first row to include
pivoted_by_leaf (bool): whether only the leaf mapping is pivoted
Returns:
list of list: Unpivoted rows
Expand All @@ -246,7 +260,10 @@ def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header,
pivoted_pos = [-(m.position + 1) for m in pivoted] # (-1) -> (0), (-2) -> (1), (-3) -> (2), etc.
non_pivoted_pos = [m.position for m in non_pivoted]
# Collect pivoted rows
pivoted_rows = [rows[pos] for pos in pivoted_pos] if non_pivoted_pos else rows
if not pivoted_by_leaf:
pivoted_rows = [rows[pos] for pos in pivoted_pos] if non_pivoted_pos else rows
else:
pivoted_rows = [rows[pos + read_start_row] for pos in pivoted_pos] if non_pivoted_pos else rows[read_start_row:]
# Prepend as many headers as needed
for m in pivoted_from_header:
pivoted.insert(0, m)
Expand Down
3 changes: 1 addition & 2 deletions spinedb_api/import_mapping/import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,8 +974,7 @@ def _default_scenario_alternative_mapping():
ScenarioAlternativeMapping: root mapping
"""
root_mapping = ScenarioMapping(Position.hidden)
scen_alt_mapping = root_mapping.child = ScenarioAlternativeMapping(Position.hidden)
scen_alt_mapping.child = ScenarioBeforeAlternativeMapping(Position.hidden)
root_mapping.child = ScenarioAlternativeMapping(Position.hidden)
return root_mapping


Expand Down
34 changes: 33 additions & 1 deletion tests/export_mapping/test_export_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
ParameterValueMapping,
ParameterValueTypeMapping,
ScenarioAlternativeMapping,
ScenarioBeforeAlternativeMapping,
ScenarioDescriptionMapping,
ScenarioMapping,
drop_non_positioned_tail,
Expand Down Expand Up @@ -972,7 +973,38 @@ def test_scenario_alternative_mapping(self):
tables = {}
for title, title_key in titles(scenario_mapping, db_map):
tables[title] = list(rows(scenario_mapping, db_map, title_key))
self.assertEqual(tables, {None: [["s1", "a1"], ["s1", "a2"], ["s2", "a2"], ["s2", "a3"]]})
self.assertEqual(tables, {None: [["s1", "a1"], ["s1", "a2"], ["s2", "a3"], ["s2", "a2"]]})
db_map.close()

def test_scenario_alternative_mapping_exports_alternatives_in_correct_order(self):
    """Alternatives must be exported in scenario rank order, not in name or id order."""
    db_map = DatabaseMapping("sqlite://", create=True)
    import_alternatives(db_map, ("a1", "a2", "a3"))
    import_scenarios(db_map, ("s1",))
    # Triples are (scenario, alternative, before_alternative): a1 before a2 and a3 before a2,
    # so the resolved rank order is a1, a3, a2.
    import_scenario_alternatives(db_map, (("s1", "a2"), ("s1", "a1", "a2"), ("s1", "a3", "a2")))
    db_map.commit_session("Add test data.")
    root_mapping = ScenarioMapping(0)
    root_mapping.child = ScenarioAlternativeMapping(1)
    exported = {
        title: list(rows(root_mapping, db_map, title_key))
        for title, title_key in titles(root_mapping, db_map)
    }
    self.assertEqual(exported, {None: [["s1", "a1"], ["s1", "a3"], ["s1", "a2"]]})
    db_map.close()

def test_legacy_scenario_alternative_mapping_with_before_alternatives(self):
    """Legacy layout: a ScenarioBeforeAlternativeMapping child makes the export emit
    (scenario, alternative, before_alternative) rows."""
    db_map = DatabaseMapping("sqlite://", create=True)
    import_alternatives(db_map, ("a1", "a2", "a3"))
    import_scenarios(db_map, ("s1",))
    # Triples are (scenario, alternative, before_alternative): a1 before a2 and a3 before a2.
    import_scenario_alternatives(db_map, (("s1", "a2"), ("s1", "a1", "a2"), ("s1", "a3", "a2")))
    db_map.commit_session("Add test data.")
    scenario_mapping = ScenarioMapping(0)
    scenario_alternative_mapping = ScenarioAlternativeMapping(1)
    scenario_mapping.child = scenario_alternative_mapping
    # Legacy: the 'before alternative' mapping hangs under the alternative mapping.
    scenario_alternative_mapping.child = ScenarioBeforeAlternativeMapping(2)
    tables = {}
    for title, title_key in titles(scenario_mapping, db_map):
        tables[title] = list(rows(scenario_mapping, db_map, title_key))
    # Each exported row pairs an alternative with the one that follows it in rank order.
    self.assertEqual(tables, {None: [["s1", "a1", "a3"], ["s1", "a3", "a2"]]})
    db_map.close()

def test_header(self):
Expand Down
22 changes: 22 additions & 0 deletions tests/import_mapping/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,28 @@ def test_import_parameter_types(self):
},
)

def test_skip_first_row_when_importing_pivoted_data(self):
    """read_start_row=1 must skip the first table row even when a mapping is pivoted."""
    # The first row would otherwise be consumed as pivoted header data;
    # the scenario mapping is told to start reading at row 1.
    table = [
        [None, "alternative1", "alternative2", "alternative3"],
        ["Scenario1", "Base", "fixed_prices", None],
    ]
    mappings = [
        [
            {"map_type": "Scenario", "position": 0, "read_start_row": 1},
            {"map_type": "ScenarioAlternative", "position": -1},
        ]
    ]
    convert_functions = {column: value_to_convert_spec("string") for column in range(4)}
    mapped_data, errors = get_mapped_data(iter(table), mappings, column_convert_fns=convert_functions)
    self.assertEqual(errors, [])
    self.assertEqual(
        mapped_data,
        {"scenario_alternatives": [["Scenario1", "Base"], ["Scenario1", "fixed_prices"]]},
    )


# Allow running this test module directly (outside a test runner invocation).
if __name__ == "__main__":
    unittest.main()

0 comments on commit 6eb6e26

Please sign in to comment.