-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More powerful setlabels #985
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
Syntax changes | ||
^^^^^^^^^^^^^^ | ||
|
||
* renamed ``Array.old_method_name()`` to :py:obj:`Array.new_method_name()` (closes :issue:`1`). | ||
* ``Axis.apply()`` and ``Axis.replace()`` are deprecated in favor of :py:obj:`Axis.set_labels()`. | ||
|
||
* renamed ``old_argument_name`` argument of :py:obj:`Array.method_name()` to ``new_argument_name``. | ||
|
||
|
@@ -52,6 +52,24 @@ Miscellaneous improvements | |
* made all I/O functions/methods/constructors to accept either a string or a pathlib.Path object | ||
for all arguments representing a path (closes :issue:`896`). | ||
|
||
* :py:obj:`Array.set_labels()` and :py:obj:`Axis.set_labels()` (formerly ``Axis.replace()`` and ``Axis.apply()``) now | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you check if either There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No idea, I forgot about the tutorial. |
||
accept slices, Groups or selection strings as labels to change and callables and "creation strings" as new labels, so | ||
that it is easier to change only a subset of labels or to change several labels in the same way (closes :issue:`906`). | ||
|
||
>>> arr = ndtest((2, 3)) | ||
>>> arr | ||
a\b b0 b1 b2 | ||
a0 0 1 2 | ||
a1 3 4 5 | ||
>>> arr.set_labels({'b1:': str.upper, 'a1': 'A-ONE'}) | ||
a\b b0 B1 B2 | ||
a0 0 1 2 | ||
A-ONE 3 4 5 | ||
>>> arr.set_labels('b1:', 'B1..B2') | ||
a\b b0 B1 B2 | ||
a0 0 1 2 | ||
a1 3 4 5 | ||
|
||
* added type hints for all remaining functions and methods which improves autocompletion in editors (such as PyCharm). | ||
Closes :issue:`864`. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,9 @@ | |
|
||
from larray.core.abstractbases import ABCAxis, ABCAxisReference, ABCArray | ||
from larray.core.expr import ExprNode | ||
from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_tick, _to_ticks, _to_key, _seq_summary, | ||
_idx_seq_to_slice, _seq_group_to_name, _translate_group_key_hdf, remove_nested_groups) | ||
from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_label, _to_labels, _to_key, _seq_summary, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe mention in the commit message that you renamed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thx |
||
_idx_seq_to_slice, _seq_group_to_name, _translate_group_key_hdf, remove_nested_groups, | ||
_to_label_or_labels) | ||
from larray.util.oset import OrderedSet | ||
from larray.util.misc import (duplicates, array_lookup2, ReprString, index_by_id, renamed_to, common_type, LHDFStore, | ||
lazy_attribute, _isnoneslice, unique_list, unique_multi, Product, argsort, has_duplicates, | ||
|
@@ -195,7 +196,7 @@ def labels(self, labels): | |
labels = np.arange(length) | ||
iswildcard = True | ||
else: | ||
labels = _to_ticks(labels, parse_single_int=True) | ||
labels = _to_labels(labels, parse_single_int=True) | ||
length = len(labels) | ||
iswildcard = False | ||
|
||
|
@@ -883,7 +884,7 @@ def _ipython_key_completions_(self) -> List[Scalar]: | |
|
||
def __contains__(self, key) -> bool: | ||
# TODO: ideally, _to_label shouldn't be necessary, the __hash__ and __eq__ of Group should include this | ||
return _to_tick(key) in self._mapping | ||
return _to_label(key) in self._mapping | ||
|
||
# use the default hash. We have to specify it explicitly because we define __eq__ | ||
__hash__ = object.__hash__ | ||
|
@@ -905,7 +906,7 @@ def index(self, key) -> Union[int, np.ndarray, slice]: | |
|
||
Returns | ||
------- | ||
(array of) int | ||
int, slice, np.ndarray or Array | ||
Numerical index(ices) of (all) label(s) represented by the key | ||
|
||
Notes | ||
|
@@ -919,14 +920,17 @@ def index(self, key) -> Union[int, np.ndarray, slice]: | |
3 | ||
>>> people.index(people.containing('Bruce')) | ||
array([1, 2]) | ||
>>> a = Axis('a0..a5', 'a') | ||
>>> a.index('a1,a3,a2..a4') | ||
array([1, 3, 2, 3, 4]) | ||
""" | ||
mapping = self._mapping | ||
|
||
if isinstance(key, Group) and key.axis is not self and key.axis is not None: | ||
try: | ||
# XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last | ||
# resort | ||
potential_tick = _to_tick(key) | ||
potential_tick = _to_label(key) | ||
# avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test | ||
if self._is_key_type_compatible(potential_tick): | ||
return mapping[potential_tick] | ||
|
@@ -1121,73 +1125,91 @@ def copy(self) -> 'Axis': | |
new_axis.__sorted_values = self.__sorted_values | ||
return new_axis | ||
|
||
def replace(self, old, new=None) -> 'Axis': | ||
def set_labels(self, old_or_changes, new=None) -> 'Axis': | ||
r""" | ||
Returns a new axis with some labels replaced. | ||
Returns a new axis with some labels changed. | ||
|
||
Parameters | ||
---------- | ||
old : any scalar (bool, int, str, ...), tuple/list/array of scalars, or a mapping. | ||
the label(s) to be replaced. Old can be a mapping {old1: new1, old2: new2, ...} | ||
new : any scalar (bool, int, str, ...) or tuple/list/array of scalars, optional | ||
the new label(s). This is argument must not be used if old is a mapping. | ||
It supports three distinct syntax variants: | ||
|
||
Returns | ||
------- | ||
Axis | ||
a new Axis with the old labels replaced by new labels. | ||
* Axis.set_labels(new_labels) -> replace all Axis labels by `new_labels` | ||
* Axis.set_labels(label_selection, new_labels) -> replace selection of labels by `new_labels` | ||
* Axis.set_labels({old1: new1, old2: new2}) -> replace each selection of labels by corresponding new labels | ||
|
||
Examples | ||
-------- | ||
>>> sex = Axis('sex=M,F') | ||
>>> sex | ||
Axis(['M', 'F'], 'sex') | ||
>>> sex.replace('M', 'Male') | ||
Axis(['Male', 'F'], 'sex') | ||
>>> sex.replace({'M': 'Male', 'F': 'Female'}) | ||
Axis(['Male', 'Female'], 'sex') | ||
>>> sex.replace(['M', 'F'], ['Male', 'Female']) | ||
Axis(['Male', 'Female'], 'sex') | ||
""" | ||
if isinstance(old, dict): | ||
new = list(old.values()) | ||
old = list(old.keys()) | ||
elif np.isscalar(old): | ||
assert new is not None and np.isscalar(new), f"{new} is not a scalar but a {type(new).__name__}" | ||
old = [old] | ||
new = [new] | ||
else: | ||
seq = (tuple, list, np.ndarray) | ||
assert isinstance(old, seq), f"{old} is not a sequence but a {type(old).__name__}" | ||
assert isinstance(new, seq), f"{new} is not a sequence but a {type(new).__name__}" | ||
assert len(old) == len(new) | ||
# using object dtype because new labels length can be larger than the fixed str length in the self.labels array | ||
labels = self.labels.astype(object) | ||
indices = self.index(old) | ||
labels[indices] = new | ||
return Axis(labels, self.name) | ||
|
||
def apply(self, func) -> 'Axis': | ||
r""" | ||
Returns a new axis with the labels transformed by func. | ||
Additionally, new labels in any of the above forms can be a function which transforms the existing | ||
labels to produce the actual new labels. | ||
|
||
Parameters | ||
---------- | ||
func : callable | ||
A callable which takes a single argument and returns a single value. | ||
old_or_changes : any scalar (bool, int, str, ...), tuple/list/array of scalars, Group, callable or mapping. | ||
This can be either: | ||
|
||
* A selection of label(s) to be replaced. This can take several forms: | ||
- a single label (e.g. 'France') | ||
- a list of labels (e.g. ['France', 'Germany']) | ||
- a comma-separated string of labels (e.g. 'France,Germany') | ||
- a Group (e.g. country['France']) | ||
* A mapping {selection1: new_labels1, selection2: new_labels2, ...} | ||
* New labels, in which case all the axis labels will be replaced by these new labels and | ||
the `new` argument must not be used. | ||
new : any scalar (bool, int, str, ...) or tuple/list/array of scalars or callable, optional | ||
The new label(s) or function to apply to old labels to get the new labels. This argument must not be | ||
used if `old_or_changes` contains the new labels or if it is a mapping. | ||
|
||
Returns | ||
------- | ||
Axis | ||
a new Axis with the transformed labels. | ||
a new Axis with the old labels replaced by new labels. | ||
|
||
Examples | ||
-------- | ||
>>> sex = Axis('sex=MALE,FEMALE') | ||
>>> sex.apply(str.capitalize) | ||
Axis(['Male', 'Female'], 'sex') | ||
""" | ||
return Axis(np_frompyfunc(func, 1, 1)(self.labels), self.name) | ||
>>> country = Axis('country=be,de,fr') | ||
>>> country | ||
Axis(['be', 'de', 'fr'], 'country') | ||
>>> country.set_labels('be', 'Belgium') | ||
Axis(['Belgium', 'de', 'fr'], 'country') | ||
>>> country.set_labels({'de': 'Germany', 'fr': 'France'}) | ||
Axis(['be', 'Germany', 'France'], 'country') | ||
>>> country.set_labels(['be', 'fr'], ['Belgium', 'France']) | ||
Axis(['Belgium', 'de', 'France'], 'country') | ||
>>> country.set_labels('be,de', 'Belgium-Germany') | ||
Axis(['Belgium-Germany', 'Belgium-Germany', 'fr'], 'country') | ||
>>> country.set_labels('be,de', ['Belgium', 'Germany']) | ||
Axis(['Belgium', 'Germany', 'fr'], 'country') | ||
>>> country.set_labels(str.upper) | ||
Axis(['BE', 'DE', 'FR'], 'country') | ||
""" | ||
# FIXME: compute max(length of new keys and old labels array) instead | ||
# XXX: it might be easier to go via list to get the label type auto-detection | ||
# labels = self.labels.tolist() | ||
|
||
# using object dtype because new labels length can be larger than the fixed str length in self.labels | ||
labels = self.labels.astype(object) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will the labels of the returned axis be always of the type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, they will. That's clearly suboptimal but already better than broken labels IMO. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, maybe store the dtype in the beginning of the method and re-apply There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's not as easy, because we can easily get mixed types labels in the result (especially when making changes to only a subset of labels) even when the original dtype is not str. Imagine applying str.format() or whatever function which converts integers to strings. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If one gets mixed types labels in the result, shouldn't we force her/him to first convert the type of the labels to str? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Converting everything to string is painful/surprising too. The usual case where you want a mixed type axis is when you have an integer axis (e.g. age) and you add a "total" label, or the opposite: you have a string axis with some special aggregate labels ("total", etc.) and want to convert all "number strings" to integers, but not the special labels. There was a problem hiding this comment. 
Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes indeed. I forgot that specific case. I guess the current implementation is OK then. |
||
get_indices = self.index | ||
|
||
def apply_changes(selection, label_change): | ||
old_indices = get_indices(selection) | ||
if callable(label_change): | ||
old_labels = labels[old_indices] | ||
if isinstance(old_labels, np.ndarray): | ||
np_func = np_frompyfunc(label_change, 1, 1) | ||
new_labels = np_func(old_labels) | ||
else: | ||
new_labels = label_change(old_labels) | ||
else: | ||
new_labels = _to_label_or_labels(label_change) | ||
labels[old_indices] = new_labels | ||
|
||
if new is None and not isinstance(old_or_changes, dict): | ||
apply_changes(slice(None), old_or_changes) | ||
elif new is not None: | ||
apply_changes(old_or_changes, new) | ||
else: | ||
assert new is None and isinstance(old_or_changes, dict) | ||
for old, new in old_or_changes.items(): | ||
apply_changes(old, new) | ||
return Axis(labels, self.name) | ||
apply = renamed_to(set_labels, 'apply') | ||
replace = renamed_to(set_labels, 'replace') | ||
|
||
# XXX: rename to named like Group? | ||
def rename(self, name) -> 'Axis': | ||
|
@@ -1196,7 +1218,7 @@ def rename(self, name) -> 'Axis': | |
|
||
Parameters | ||
---------- | ||
name : str | ||
name : str, Axis | ||
the new name for the axis. | ||
|
||
Returns | ||
|
@@ -1252,7 +1274,7 @@ def union(self, other) -> 'Axis': | |
""" | ||
if isinstance(other, str): | ||
# TODO : remove [other] if ... when FutureWarning raised in Axis.init will be removed | ||
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other] | ||
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other] | ||
if isinstance(other, Axis): | ||
other = other.labels | ||
return Axis(unique_multi((self.labels, other)), self.name) | ||
|
@@ -1288,7 +1310,7 @@ def intersection(self, other) -> 'Axis': | |
""" | ||
if isinstance(other, str): | ||
# TODO : remove [other] if ... when FutureWarning raised in Axis.init will be removed | ||
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other] | ||
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other] | ||
if isinstance(other, Axis): | ||
other = other.labels | ||
to_keep = set(other) | ||
|
@@ -1325,7 +1347,7 @@ def difference(self, other) -> 'Axis': | |
""" | ||
if isinstance(other, str): | ||
# TODO : remove [other] if ... when FutureWarning raised in Axis.init will be removed | ||
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other] | ||
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other] | ||
if isinstance(other, Axis): | ||
other = other.labels | ||
to_drop = set(other) | ||
|
@@ -2567,24 +2589,13 @@ def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'AxisCo | |
# handle {label1: new_label1, label2: new_label2} | ||
if any(axis_ref not in self for axis_ref in changes.keys()): | ||
changes_per_axis = defaultdict(list) | ||
for selection, new_labels in changes.items(): | ||
for selection, label_changes in changes.items(): | ||
group = self._guess_axis(selection) | ||
changes_per_axis[group.axis].append((selection, new_labels)) | ||
changes_per_axis[group.axis].append((group, label_changes)) | ||
changes = {axis: dict(axis_changes) for axis, axis_changes in changes_per_axis.items()} | ||
|
||
new_axes = [] | ||
for old_axis, axis_changes in changes.items(): | ||
real_axis = self[old_axis] | ||
if isinstance(axis_changes, dict): | ||
new_axis = real_axis.replace(axis_changes) | ||
# TODO: we should implement the non-dict behavior in Axis.replace, so that we can simplify this code to: | ||
# new_axes = [self[old_axis].replace(axis_changes) for old_axis, axis_changes in changes.items()] | ||
elif callable(axis_changes): | ||
new_axis = real_axis.apply(axis_changes) | ||
else: | ||
new_axis = Axis(axis_changes, real_axis.name) | ||
new_axes.append((real_axis, new_axis)) | ||
return self.replace(new_axes, inplace=inplace) | ||
return self.replace({old_axis: self[old_axis].set_labels(axis_changes) for old_axis, axis_changes in | ||
changes.items()}, inplace=inplace) | ||
|
||
# TODO: deprecate method (should use __sub__ instead) | ||
def without(self, axes) -> 'AxisCollection': | ||
|
@@ -3428,6 +3439,7 @@ def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'Axis | |
See Also | ||
-------- | ||
Array.align | ||
Axis.align | ||
|
||
Examples | ||
-------- | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/renamed//
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are
Axis.apply()
and/or Axis.replace()
still mentioned in the api.rst file? If yes, please remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You made me realize I forgot to update api.rst for this PR. So yes, they are still mentioned and Axis.set_labels is not.