More powerful setlabels #985

Open: wants to merge 2 commits into base: master
20 changes: 19 additions & 1 deletion doc/source/changes/version_0_34.rst.inc
@@ -4,7 +4,7 @@
Syntax changes
^^^^^^^^^^^^^^

* renamed ``Array.old_method_name()`` to :py:obj:`Array.new_method_name()` (closes :issue:`1`).
* renamed ``Axis.apply()`` and ``Axis.replace()`` are deprecated in favor of :py:obj:`Axis.set_labels()`.
Collaborator

s/renamed//

Contributor Author

😄

Collaborator

Are Axis.apply() and/or Axis.replace() still mentioned in the api.rst file?
If so, please remove them.

Contributor Author

You made me realize I forgot to update api.rst for this PR. So yes, they are still mentioned and Axis.set_labels is not.
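
For readers following the deprecation discussed here: the diff keeps `Axis.apply` and `Axis.replace` as aliases created with `renamed_to(set_labels, ...)`. The sketch below shows how such a deprecated alias could behave in general. `deprecated_alias` and the `Labels` class are hypothetical names used only for illustration; this is not larray's actual `renamed_to` implementation, and the choice of `FutureWarning` is an assumption.

```python
import functools
import warnings


def deprecated_alias(new_func, old_name):
    """Hypothetical helper: return a wrapper that warns, then calls new_func."""
    @functools.wraps(new_func)
    def wrapper(*args, **kwargs):
        # assumption: FutureWarning is used so that end users see the message
        warnings.warn(f"{old_name}() is deprecated. Use {new_func.__name__}() instead.",
                      FutureWarning, stacklevel=2)
        return new_func(*args, **kwargs)
    return wrapper


class Labels:
    """Toy stand-in for an axis: it only stores a list of labels."""
    def __init__(self, labels):
        self.labels = list(labels)

    def set_labels(self, func):
        # return a new object with func applied to every label
        return Labels(func(label) for label in self.labels)

    # the old name still works but emits a warning before forwarding
    apply = deprecated_alias(set_labels, 'apply')


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = Labels(['male', 'female']).apply(str.capitalize)
```

With an alias like this, existing user code that calls the old name keeps working while the documentation (api.rst, tutorial) can move to the new name.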


* renamed ``old_argument_name`` argument of :py:obj:`Array.method_name()` to ``new_argument_name``.

@@ -52,6 +52,24 @@ Miscellaneous improvements
* made all I/O functions/methods/constructors to accept either a string or a pathlib.Path object
for all arguments representing a path (closes :issue:`896`).

* :py:obj:`Array.set_labels()` and :py:obj:`Axis.set_labels()` (formerly ``Axis.replace()`` and ``Axis.apply()``) now
Collaborator

Did you check whether either Axis.replace() or Axis.apply() was used in the tutorial?
If so, please replace it with Axis.set_labels().

Contributor Author

No idea, I forgot about the tutorial.

accepts slices, Groups or selection strings as labels to change and callable and "creation strings" as new labels, so
that it is easier to change only a subset of labels or to change several labels in the same way (closes :issue:`906`).

>>> arr = ndtest((2, 3))
>>> arr
a\b b0 b1 b2
a0 0 1 2
a1 3 4 5
>>> arr.set_labels({'b1:': str.upper, 'a1': 'A-ONE'})
a\b b0 B1 B2
a0 0 1 2
A-ONE 3 4 5
>>> arr.set_labels('b1:', 'B1..B2')
a\b b0 B1 B2
a0 0 1 2
a1 3 4 5

* added type hints for all remaining functions and methods which improves autocompletion in editors (such as PyCharm).
Closes :issue:`864`.

12 changes: 8 additions & 4 deletions larray/core/array.py
@@ -7424,7 +7424,6 @@ def __array__(self, dtype=None):

__array_priority__ = 100

# TODO: this should be a thin wrapper around a method in AxisCollection
def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'Array':
r"""Replaces the labels of one or several axes of the array.

@@ -7522,13 +7521,18 @@ def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'Array'
nat\sex Men F
Belgian 0 1
FO 2 3

>>> a.set_labels({'M:F': str.lower, 'BE': 'Belgian', 'FO': 'Foreigner'})
nat\sex m f
Belgian 0 1
Foreigner 2 3
"""
axes = self.axes.set_labels(axis, labels, **kwargs)
new_axes = self.axes.set_labels(axis, labels, **kwargs)
if inplace:
self.axes = axes
self.axes = new_axes
return self
else:
return Array(self.data, axes)
return Array(self.data, new_axes)

def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True) -> 'Array':
return Array(self.data.astype(dtype, order, casting, subok, copy), self.axes)
168 changes: 90 additions & 78 deletions larray/core/axis.py
@@ -11,8 +11,9 @@

from larray.core.abstractbases import ABCAxis, ABCAxisReference, ABCArray
from larray.core.expr import ExprNode
from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_tick, _to_ticks, _to_key, _seq_summary,
_idx_seq_to_slice, _seq_group_to_name, _translate_group_key_hdf, remove_nested_groups)
from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_label, _to_labels, _to_key, _seq_summary,
Collaborator

Maybe mention in the commit message that you renamed tick to label?

Contributor Author

ok

Collaborator

Thx

_idx_seq_to_slice, _seq_group_to_name, _translate_group_key_hdf, remove_nested_groups,
_to_label_or_labels)
from larray.util.oset import OrderedSet
from larray.util.misc import (duplicates, array_lookup2, ReprString, index_by_id, renamed_to, common_type, LHDFStore,
lazy_attribute, _isnoneslice, unique_list, unique_multi, Product, argsort, has_duplicates,
@@ -195,7 +196,7 @@ def labels(self, labels):
labels = np.arange(length)
iswildcard = True
else:
labels = _to_ticks(labels, parse_single_int=True)
labels = _to_labels(labels, parse_single_int=True)
length = len(labels)
iswildcard = False

@@ -883,7 +884,7 @@ def _ipython_key_completions_(self) -> List[Scalar]:

def __contains__(self, key) -> bool:
# TODO: ideally, _to_tick shouldn't be necessary, the __hash__ and __eq__ of Group should include this
return _to_tick(key) in self._mapping
return _to_label(key) in self._mapping

# use the default hash. We have to specify it explicitly because we define __eq__
__hash__ = object.__hash__
@@ -905,7 +906,7 @@ def index(self, key) -> Union[int, np.ndarray, slice]:

Returns
-------
(array of) int
int, slice, np.ndarray or Array
Numerical index(ices) of (all) label(s) represented by the key

Notes
@@ -919,14 +920,17 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
3
>>> people.index(people.containing('Bruce'))
array([1, 2])
>>> a = Axis('a0..a5', 'a')
>>> a.index('a1,a3,a2..a4')
array([1, 3, 2, 3, 4])
"""
mapping = self._mapping

if isinstance(key, Group) and key.axis is not self and key.axis is not None:
try:
# XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last
# resort
potential_tick = _to_tick(key)
potential_tick = _to_label(key)
# avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test
if self._is_key_type_compatible(potential_tick):
return mapping[potential_tick]
@@ -1121,73 +1125,91 @@ def copy(self) -> 'Axis':
new_axis.__sorted_values = self.__sorted_values
return new_axis

def replace(self, old, new=None) -> 'Axis':
def set_labels(self, old_or_changes, new=None) -> 'Axis':
r"""
Returns a new axis with some labels replaced.
Returns a new axis with some labels changed.

Parameters
----------
old : any scalar (bool, int, str, ...), tuple/list/array of scalars, or a mapping.
the label(s) to be replaced. Old can be a mapping {old1: new1, old2: new2, ...}
new : any scalar (bool, int, str, ...) or tuple/list/array of scalars, optional
the new label(s). This is argument must not be used if old is a mapping.
It supports three distinct syntax variants:

Returns
-------
Axis
a new Axis with the old labels replaced by new labels.
* Axis.set_labels(new_labels) -> replace all Axis labels by `new_labels`
* Axis.set_labels(label_selection, new_labels) -> replace selection of labels by `new_labels`
* Axis.set_labels({old1: new1, old2: new2}) -> replace each selection of labels by corresponding new labels

Examples
--------
>>> sex = Axis('sex=M,F')
>>> sex
Axis(['M', 'F'], 'sex')
>>> sex.replace('M', 'Male')
Axis(['Male', 'F'], 'sex')
>>> sex.replace({'M': 'Male', 'F': 'Female'})
Axis(['Male', 'Female'], 'sex')
>>> sex.replace(['M', 'F'], ['Male', 'Female'])
Axis(['Male', 'Female'], 'sex')
"""
if isinstance(old, dict):
new = list(old.values())
old = list(old.keys())
elif np.isscalar(old):
assert new is not None and np.isscalar(new), f"{new} is not a scalar but a {type(new).__name__}"
old = [old]
new = [new]
else:
seq = (tuple, list, np.ndarray)
assert isinstance(old, seq), f"{old} is not a sequence but a {type(old).__name__}"
assert isinstance(new, seq), f"{new} is not a sequence but a {type(new).__name__}"
assert len(old) == len(new)
# using object dtype because new labels length can be larger than the fixed str length in the self.labels array
labels = self.labels.astype(object)
indices = self.index(old)
labels[indices] = new
return Axis(labels, self.name)

def apply(self, func) -> 'Axis':
r"""
Returns a new axis with the labels transformed by func.
Additionally, new labels in any of the above forms can be a function which transforms the existing
labels to produce the actual new labels.

Parameters
----------
func : callable
A callable which takes a single argument and returns a single value.
old_or_changes : any scalar (bool, int, str, ...), tuple/list/array of scalars, Group, callable or mapping.
This can be either:

* A selection of label(s) to be replaced. This can take several forms:
- a single label (e.g. 'France')
- a list of labels (e.g. ['France', 'Germany'])
- a comma-separated string of labels (e.g. 'France,Germany')
- a Group (e.g. country['France'])
* A mapping {selection1: new_labels1, selection2: new_labels2, ...}
* New labels, in which case all the axis labels will be replaced by these new labels and
the `new` argument must not be used.
new : any scalar (bool, int, str, ...) or tuple/list/array of scalars or callable, optional
The new label(s) or function to apply to old labels to get the new labels. This argument must not be
used if `old_or_changes` contains the new labels or if it is a mapping.

Returns
-------
Axis
a new Axis with the transformed labels.
a new Axis with the old labels replaced by new labels.

Examples
--------
>>> sex = Axis('sex=MALE,FEMALE')
>>> sex.apply(str.capitalize)
Axis(['Male', 'Female'], 'sex')
"""
return Axis(np_frompyfunc(func, 1, 1)(self.labels), self.name)
>>> country = Axis('country=be,de,fr')
>>> country
Axis(['be', 'de', 'fr'], 'country')
>>> country.set_labels('be', 'Belgium')
Axis(['Belgium', 'de', 'fr'], 'country')
>>> country.set_labels({'de': 'Germany', 'fr': 'France'})
Axis(['be', 'Germany', 'France'], 'country')
>>> country.set_labels(['be', 'fr'], ['Belgium', 'France'])
Axis(['Belgium', 'de', 'France'], 'country')
>>> country.set_labels('be,de', 'Belgium-Germany')
Axis(['Belgium-Germany', 'Belgium-Germany', 'fr'], 'country')
>>> country.set_labels('be,de', ['Belgium', 'Germany'])
Axis(['Belgium', 'Germany', 'fr'], 'country')
>>> country.set_labels(str.upper)
Axis(['BE', 'DE', 'FR'], 'country')
"""
# FIXME: compute max(length of new keys and old labels array) instead
# XXX: it might be easier to go via list to get the label type auto-detection
# labels = self.labels.tolist()

# using object dtype because new labels length can be larger than the fixed str length in self.labels
labels = self.labels.astype(object)
Collaborator

Will the labels of the returned axis always be of type object?
What about non-string labels (e.g. int)?

Contributor Author

Yes, they will. That's clearly suboptimal but already better than broken labels IMO.

Collaborator

So, maybe store the dtype at the beginning of the method and re-apply astype() at the end if the original type was not str?

Contributor Author

That's not as easy, because we can easily get mixed-type labels in the result (especially when making changes to only a subset of labels) even when the original dtype is not str. Imagine applying str.format() or any other function which converts integers to strings.

Collaborator

If one gets mixed-type labels in the result, shouldn't we force her/him to first convert the type of the labels to str?
I'm afraid that the only case where users will get mixed-type labels is when they don't realize they do.

Contributor Author

Converting everything to string is painful/surprising too. The usual case where you want a mixed type axis is when you have an integer axis (e.g. age) and you add a "total" label, or the opposite: you have a string axis with some special aggregate labels ("total", etc.) and want to convert all "number strings" to integers, but not the special labels.

Collaborator

> The usual case where you want a mixed type axis is when you have an integer axis (e.g. age) and you add a "total" label

Yes indeed. I forgot that specific case. I guess the current implementation is OK then.
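
The two mixed-type use cases raised in this thread can be illustrated with plain Python lists. This is only a sketch: the lists stand in for axis labels, no larray object is involved, and `to_int_if_numeric` is a hypothetical helper.

```python
# Case 1: an integer axis (e.g. age) to which a "total" label is added.
age_labels = [0, 1, 2]
with_total = age_labels + ['total']


def to_int_if_numeric(label):
    """Convert digit-only strings to int and leave every other label as is."""
    if isinstance(label, str) and label.isdigit():
        return int(label)
    return label


# Case 2: a string axis with a special aggregate label, where only the
# "number strings" should become integers.
str_labels = ['0', '1', '2', 'total']
converted = [to_int_if_numeric(label) for label in str_labels]
```

In both cases the resulting labels mix integers and strings, so neither the original dtype nor a blanket conversion to str would describe them; a generic object container can.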

get_indices = self.index

def apply_changes(selection, label_change):
old_indices = get_indices(selection)
if callable(label_change):
old_labels = labels[old_indices]
if isinstance(old_labels, np.ndarray):
np_func = np_frompyfunc(label_change, 1, 1)
new_labels = np_func(old_labels)
else:
new_labels = label_change(old_labels)
else:
new_labels = _to_label_or_labels(label_change)
labels[old_indices] = new_labels

if new is None and not isinstance(old_or_changes, dict):
apply_changes(slice(None), old_or_changes)
elif new is not None:
apply_changes(old_or_changes, new)
else:
assert new is None and isinstance(old_or_changes, dict)
for old, new in old_or_changes.items():
apply_changes(old, new)
return Axis(labels, self.name)
apply = renamed_to(set_labels, 'apply')
replace = renamed_to(set_labels, 'replace')
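
To make the dispatch between the three call forms easier to follow, here is a simplified pure-Python sketch of the same control flow. It is not the larray implementation: `set_labels_sketch` is a hypothetical function that works on plain lists, a selection can only be a single label or a list/tuple of labels (no slices, Groups or selection strings), and "creation strings" are not supported.

```python
def set_labels_sketch(labels, old_or_changes, new=None):
    """Return a new list of labels with some (or all) labels changed."""
    original = list(labels)
    result = list(original)

    def apply_change(selection, change):
        # simplification: a selection is one label or a list/tuple of labels
        selected = list(selection) if isinstance(selection, (list, tuple)) else [selection]
        # positions are looked up in the original labels, as in the diff
        indices = [original.index(label) for label in selected]
        if callable(change):
            new_labels = [change(result[i]) for i in indices]
        elif isinstance(change, (list, tuple)):
            new_labels = list(change)
        else:
            # a single new label is repeated for every selected position
            new_labels = [change] * len(indices)
        for i, new_label in zip(indices, new_labels):
            result[i] = new_label

    if new is None and not isinstance(old_or_changes, dict):
        # set_labels(new_labels): change all labels
        apply_change(original, old_or_changes)
    elif new is not None:
        # set_labels(selection, new_labels)
        apply_change(old_or_changes, new)
    else:
        # set_labels({selection1: new_labels1, ...})
        for selection, change in old_or_changes.items():
            apply_change(selection, change)
    return result


country = ['be', 'de', 'fr']
one = set_labels_sketch(country, 'be', 'Belgium')
mapping = set_labels_sketch(country, {'de': 'Germany', 'fr': 'France'})
subset = set_labels_sketch(country, ['be', 'fr'], ['Belgium', 'France'])
upper = set_labels_sketch(country, str.upper)
```

The branch order mirrors the diff: a non-dict first argument without `new` means "change everything", a first argument together with `new` means "change this selection", and a dict is handled last.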

# XXX: rename to named like Group?
def rename(self, name) -> 'Axis':
@@ -1196,7 +1218,7 @@ def rename(self, name) -> 'Axis':

Parameters
----------
name : str
name : str, Axis
the new name for the axis.

Returns
@@ -1252,7 +1274,7 @@ def union(self, other) -> 'Axis':
"""
if isinstance(other, str):
# TODO : remove [other] if ... when FuturWarning raised in Axis.init will be removed
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other]
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other]
if isinstance(other, Axis):
other = other.labels
return Axis(unique_multi((self.labels, other)), self.name)
@@ -1288,7 +1310,7 @@ def intersection(self, other) -> 'Axis':
"""
if isinstance(other, str):
# TODO : remove [other] if ... when FuturWarning raised in Axis.init will be removed
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other]
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other]
if isinstance(other, Axis):
other = other.labels
to_keep = set(other)
@@ -1325,7 +1347,7 @@ def difference(self, other) -> 'Axis':
"""
if isinstance(other, str):
# TODO : remove [other] if ... when FuturWarning raised in Axis.init will be removed
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other]
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other]
if isinstance(other, Axis):
other = other.labels
to_drop = set(other)
@@ -2567,24 +2589,13 @@ def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'AxisCo
# handle {label1: new_label1, label2: new_label2}
if any(axis_ref not in self for axis_ref in changes.keys()):
changes_per_axis = defaultdict(list)
for selection, new_labels in changes.items():
for selection, label_changes in changes.items():
group = self._guess_axis(selection)
changes_per_axis[group.axis].append((selection, new_labels))
changes_per_axis[group.axis].append((group, label_changes))
changes = {axis: dict(axis_changes) for axis, axis_changes in changes_per_axis.items()}

new_axes = []
for old_axis, axis_changes in changes.items():
real_axis = self[old_axis]
if isinstance(axis_changes, dict):
new_axis = real_axis.replace(axis_changes)
# TODO: we should implement the non-dict behavior in Axis.replace, so that we can simplify this code to:
# new_axes = [self[old_axis].replace(axis_changes) for old_axis, axis_changes in changes.items()]
elif callable(axis_changes):
new_axis = real_axis.apply(axis_changes)
else:
new_axis = Axis(axis_changes, real_axis.name)
new_axes.append((real_axis, new_axis))
return self.replace(new_axes, inplace=inplace)
return self.replace({old_axis: self[old_axis].set_labels(axis_changes) for old_axis, axis_changes in
changes.items()}, inplace=inplace)
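
The grouping step in this hunk (collecting `{label: change}` entries per axis before delegating to each axis) can be sketched with plain dictionaries. This is an illustration under assumptions: `group_changes_per_axis` is a hypothetical function, axes are represented as `{axis_name: [labels]}`, and the axis lookup is a naive membership test rather than larray's `_guess_axis`.

```python
from collections import defaultdict


def group_changes_per_axis(axes, changes):
    """Group {label: change} entries by the axis which contains each label.

    axes: {axis_name: [labels]}
    changes: {label: new_label_or_callable}
    """
    changes_per_axis = defaultdict(dict)
    for label, change in changes.items():
        # naive axis guess: the label must belong to exactly one axis
        matching = [name for name, labels in axes.items() if label in labels]
        if len(matching) != 1:
            raise ValueError(f"{label!r} matches {len(matching)} axes, expected exactly one")
        changes_per_axis[matching[0]][label] = change
    return dict(changes_per_axis)


axes = {'nat': ['BE', 'FO'], 'sex': ['M', 'F']}
grouped = group_changes_per_axis(axes, {'M': 'Men', 'BE': 'Belgian', 'FO': 'Foreigner'})
```

Once the changes are grouped like this, each axis can receive its own mapping in a single call, which is what allows the hunk to collapse the previous per-type branches into one expression.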

# TODO: deprecate method (should use __sub__ instead)
def without(self, axes) -> 'AxisCollection':
@@ -3428,6 +3439,7 @@ def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'Axis
See Also
--------
Array.align
Axis.align

Examples
--------