Skip to content

Commit

Permalink
Fix importing pivoted data with skip rows
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen committed Oct 11, 2024
1 parent 9dd944e commit 5a7c843
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
35 changes: 26 additions & 9 deletions spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
"""
Contains `get_mapped_data()` that converts rows of tabular data into a dictionary for import to a Spine DB,
using ``import_functions.import_data()``
"""

from copy import deepcopy
from operator import itemgetter
from ..exception import ParameterValueFormatError
from ..helpers import string_to_bool
from ..mapping import Position
from ..mapping import Position, is_pivoted
from ..parameter_value import (
Array,
Map,
Expand Down Expand Up @@ -97,36 +96,47 @@ def get_mapped_data(
mapping_names = []
_ensure_mapping_name_consistency(mappings, mapping_names)
for mapping, mapping_name in zip(mappings, mapping_names):
read_state = {}
mapping = deepcopy(mapping)
mapping.polish(table_name, data_header, mapping_name, column_count)
mapping_errors = check_validity(mapping)
if mapping_errors:
errors += mapping_errors
continue
read_state = {}
# Find pivoted and unpivoted mappings
pivoted, non_pivoted, pivoted_from_header, last = _split_mapping(mapping)
# If there are no pivoted mappings, we can just feed the rows to our mapping directly
if not (pivoted or pivoted_from_header):
start_pos = mapping.read_start_row
for k, row in enumerate(rows[mapping.read_start_row :]):
for k, row in enumerate(rows[start_pos:]):
if not _is_valid_row(row):
continue
row = _convert_row(row, column_convert_fns, start_pos + k, errors)
mapping.import_row(row, read_state, mapped_data)
continue
# There are pivoted mappings. We unpivot the table
pivoted_by_leaf = all(
not is_pivoted(m.position) and m.position != Position.header for m in mapping.flatten()[:-1]
)
unpivoted_rows, pivoted_pos, non_pivoted_pos, unpivoted_column_pos = _unpivot_rows(
rows, data_header, pivoted, non_pivoted, pivoted_from_header, mapping.skip_columns
rows,
data_header,
pivoted,
non_pivoted,
pivoted_from_header,
mapping.skip_columns,
mapping.read_start_row,
pivoted_by_leaf,
)
if not unpivoted_column_pos:
continue
if not is_pivoted(last.position):
last.position = -1
# Reposition row convert functions
row_convert_fns = {k: row_convert_fns[pos] for k, pos in enumerate(pivoted_pos) if pos in row_convert_fns}
# If there are only pivoted mappings, we can just feed the unpivoted rows
if not non_pivoted:
# Reposition pivoted mappings:
last.position = -1
for k, m in enumerate(pivoted):
m.position = k
for k, row in enumerate(unpivoted_rows):
Expand All @@ -140,7 +150,6 @@ def get_mapped_data(
# - The last mapping (typically, parameter value) will read from the last position in the row
# - The pivoted mappings will read from positions to the left of that
k = None
last.position = -1
for k, m in enumerate(reversed(pivoted)):
m.position = -(k + 2)
# Feed rows: To each regular row, we append each unpivoted row, plus the item at the intersection,
Expand Down Expand Up @@ -226,7 +235,9 @@ def _split_mapping(mapping):
return pivoted, non_pivoted, pivoted_from_header, flattened[-1]


def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header, skip_columns):
def _unpivot_rows(
rows, data_header, pivoted, non_pivoted, pivoted_from_header, skip_columns, read_start_row, pivoted_by_leaf
):
"""Unpivots rows.
Args:
Expand All @@ -235,6 +246,9 @@ def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header,
pivoted (list of ImportMapping): Pivoted mappings (reading from rows)
non_pivoted (list of ImportMapping): Non-pivoted mappings ('regular', reading from columns)
pivoted_from_header (list of ImportMapping): Mappings pivoted from header
skip_columns (list of int): columns that should be skipped
read_start_row (int): first row to include
pivoted_by_leaf (bool): whether only the leaf mapping is pivoted
Returns:
list of list: Unpivoted rows
Expand All @@ -246,7 +260,10 @@ def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header,
pivoted_pos = [-(m.position + 1) for m in pivoted] # (-1) -> (0), (-2) -> (1), (-3) -> (2), etc.
non_pivoted_pos = [m.position for m in non_pivoted]
# Collect pivoted rows
pivoted_rows = [rows[pos] for pos in pivoted_pos] if non_pivoted_pos else rows
if not pivoted_by_leaf:
pivoted_rows = [rows[pos] for pos in pivoted_pos] if non_pivoted_pos else rows
else:
pivoted_rows = [rows[pos + read_start_row] for pos in pivoted_pos] if non_pivoted_pos else rows[read_start_row:]
# Prepend as many headers as needed
for m in pivoted_from_header:
pivoted.insert(0, m)
Expand Down
22 changes: 22 additions & 0 deletions tests/import_mapping/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,28 @@ def test_import_parameter_types(self):
},
)

def test_skip_first_row_when_importing_pivoted_data(self):
data_source = iter(
[
[None, "alternative1", "alternative2", "alternative3"],
["Scenario1", "Base", "fixed_prices", None],
]
)
mappings = [
[
{"map_type": "Scenario", "position": 0, "read_start_row": 1},
{"map_type": "ScenarioAlternative", "position": -1},
]
]
convert_function_specs = {0: "string", 1: "string", 2: "string", 3: "string"}
convert_functions = {column: value_to_convert_spec(spec) for column, spec in convert_function_specs.items()}
mapped_data, errors = get_mapped_data(data_source, mappings, column_convert_fns=convert_functions)
self.assertEqual(errors, [])
self.assertEqual(
mapped_data,
{"scenario_alternatives": [["Scenario1", "Base"], ["Scenario1", "fixed_prices"]]},
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5a7c843

Please sign in to comment.