From 11360a0b3ed4eea5b410a58a04d514ea14db05eb Mon Sep 17 00:00:00 2001 From: smathot Date: Thu, 21 Nov 2024 16:49:10 +0100 Subject: [PATCH] Fix an issue in referencing multidimensional columns by names with dict notation --- datamatrix/_datamatrix/_datamatrix.py | 2 ++ datamatrix/_datamatrix/_multidimensionalcolumn.py | 11 +++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/datamatrix/_datamatrix/_datamatrix.py b/datamatrix/_datamatrix/_datamatrix.py index 477202d..d87a76c 100644 --- a/datamatrix/_datamatrix/_datamatrix.py +++ b/datamatrix/_datamatrix/_datamatrix.py @@ -343,6 +343,8 @@ def _slice(self, key): # and row. Therefore, we turn tuples into lists. if isinstance(key, tuple): key = list(key) + if isinstance(key[0], str): + return self._getcolbyname(key[0])[key[1:]] _rowid = self._rowid[key] dm = DataMatrix(len(_rowid)) object.__setattr__(dm, u'_rowid', _rowid) diff --git a/datamatrix/_datamatrix/_multidimensionalcolumn.py b/datamatrix/_datamatrix/_multidimensionalcolumn.py index 16dd31f..2ca3354 100644 --- a/datamatrix/_datamatrix/_multidimensionalcolumn.py +++ b/datamatrix/_datamatrix/_multidimensionalcolumn.py @@ -23,10 +23,7 @@ from datamatrix import cfg from datamatrix._datamatrix._numericcolumn import NumericColumn, FloatColumn from datamatrix._datamatrix._datamatrix import DataMatrix -try: - from collections.abc import Sequence # Python 3.3 and later -except ImportError: - from collections import Sequence +from collections.abc import Sequence, Collection from collections import OrderedDict try: import numpy as np @@ -103,7 +100,9 @@ def __init__(self, datamatrix, shape, defaultnan=True, **kwargs): normshape += (dim_size, ) self.index_names.append(list(range(dim_size))) self.index_values.append(list(range(dim_size))) - else: + elif isinstance(dim_size, Collection): + if isinstance(dim_size, str): + raise ValueError('A dimension cannot be a string') normshape += (len(dim_size), ) self.index_names.append(list(dim_size)) self.index_values.append(list(range(len(dim_size)))) @@ -417,7 +416,7 @@ def _getintkey(self, key): def __getitem__(self, key): touch_history.touch(self, try_to_load=True) - if isinstance(key, tuple) and len(key) <= len(self._seq.shape): + if isinstance(key, (tuple, list)) and len(key) <= len(self._seq.shape): # Advanced indexing always returns a copy, rather than a view, so # there's no need to explicitly copy the result. indices = self._numindices(key, accept_ellipsis=True)