Initial suggestions for array type handling, using the normalization methods as an example #835
base: main
Changes from all commits
877034d
5d5438b
0c1e041
d6dc2c9
621ea97
28251d7
33192d1
48f5936
3083b5e
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from functools import singledispatch
 from typing import TYPE_CHECKING

 import numpy as np
@@ -8,9 +9,14 @@
 from ehrapy._compat import is_dask_array

 try:
+    import dask.array as da
     import dask_ml.preprocessing as daskml_pp
+
+    DASK_AVAILABLE = True
 except ImportError:
     daskml_pp = None
+    DASK_AVAILABLE = False
+

 from ehrapy.anndata.anndata_ext import (
     assert_numeric_vars,
@@ -69,6 +75,23 @@ def _scale_func_group(
     return None


+@singledispatch
What do you think about systematically introducing singledispatch in this manner? It might be a bit of overkill here, but I think in the long run this structure reduces code complexity by introducing a regularly recurring pattern.

Yes, I agree with your argument. (A usage sketch of the pattern follows after this hunk.)
+def _scale_norm_function(arr):
+    raise NotImplementedError(f"scale_norm does not support data to be of type {type(arr)}")
+
+
+@_scale_norm_function.register
+def _(arr: np.ndarray, **kwargs):
+    return sklearn_pp.StandardScaler(**kwargs).fit_transform
+
+
+if DASK_AVAILABLE:
+
+    @_scale_norm_function.register
+    def _(arr: da.Array, **kwargs):
+        return daskml_pp.StandardScaler(**kwargs).fit_transform
+
+
 def scale_norm(
     adata: AnnData,
     vars: str | Sequence[str] | None = None,
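As a usage sketch of the dispatch pattern discussed in the comment above (the arrays and variable names here are only illustrative, not part of this PR):

    import numpy as np
    import dask.array as da

    X_np = np.random.rand(100, 5)
    X_da = da.from_array(X_np, chunks=(50, 5))

    # singledispatch picks the overload registered for the runtime type of the
    # first argument, so the same call site covers both array backends.
    scale_np = _scale_norm_function(X_np)  # sklearn StandardScaler().fit_transform
    scale_da = _scale_norm_function(X_da)  # dask-ml StandardScaler().fit_transform

Either returned callable can then be applied to its array, e.g. scale_np(X_np).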
@@ -98,10 +121,7 @@ def scale_norm(
     >>> adata_norm = ep.pp.scale_norm(adata, copy=True)
     """

-    if is_dask_array(adata.X):
-        scale_func = daskml_pp.StandardScaler(**kwargs).fit_transform
-    else:
-        scale_func = sklearn_pp.StandardScaler(**kwargs).fit_transform
+    scale_func = _scale_norm_function(adata.X, **kwargs)

     return _scale_func_group(
         adata=adata,
@@ -113,6 +133,23 @@ def scale_norm(
     )


+@singledispatch
+def _minmax_norm_function(arr):
+    raise NotImplementedError(f"minmax_norm does not support data to be of type {type(arr)}")
Ideally we should always suggest which types are supported.

You can automate this by accessing … (see the sketch after this hunk).
+
+
+@_minmax_norm_function.register
+def _(arr: np.ndarray, **kwargs):
+    return sklearn_pp.MinMaxScaler(**kwargs).fit_transform
+
+
+if DASK_AVAILABLE:
+
+    @_minmax_norm_function.register
+    def _(arr: da.Array, **kwargs):
+        return daskml_pp.MinMaxScaler(**kwargs).fit_transform
+
+
 def minmax_norm(
     adata: AnnData,
     vars: str | Sequence[str] | None = None,
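Assuming the truncated suggestion above refers to singledispatch's registry mapping, a minimal sketch of how the fallback could list the supported types (illustrative only, not part of this PR):

    from functools import singledispatch


    @singledispatch
    def _minmax_norm_function(arr):
        # registry maps every registered type to its implementation; `object`
        # is the fallback itself, so it is excluded from the hint.
        supported = [t.__name__ for t in _minmax_norm_function.registry if t is not object]
        raise NotImplementedError(
            f"minmax_norm does not support data of type {type(arr)}; "
            f"supported types: {', '.join(supported)}"
        )

With the registrations from this diff, the error would then list ndarray (and Array, once the dask overload is registered) without hard-coding the type names.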
@@ -143,10 +180,7 @@ def minmax_norm(
     >>> adata_norm = ep.pp.minmax_norm(adata, copy=True)
     """

-    if is_dask_array(adata.X):
-        scale_func = daskml_pp.MinMaxScaler(**kwargs).fit_transform
-    else:
-        scale_func = sklearn_pp.MinMaxScaler(**kwargs).fit_transform
+    scale_func = _minmax_norm_function(adata.X, **kwargs)

     return _scale_func_group(
         adata=adata,
@@ -158,6 +192,16 @@ def minmax_norm(
     )


+@singledispatch
+def _maxabs_norm_function(arr):
+    raise NotImplementedError(f"maxabs_norm does not support data to be of type {type(arr)}")
+
+
+@_maxabs_norm_function.register
+def _(arr: np.ndarray):
+    return sklearn_pp.MaxAbsScaler().fit_transform
+
+
 def maxabs_norm(
     adata: AnnData,
     vars: str | Sequence[str] | None = None,
@@ -184,10 +228,8 @@ def maxabs_norm(
     >>> adata = ep.dt.mimic_2(encoded=True)
     >>> adata_norm = ep.pp.maxabs_norm(adata, copy=True)
     """
-    if is_dask_array(adata.X):
-        raise NotImplementedError("MaxAbsScaler is not implemented in dask_ml.")
-    else:
-        scale_func = sklearn_pp.MaxAbsScaler().fit_transform
+
+    scale_func = _maxabs_norm_function(adata.X)

     return _scale_func_group(
         adata=adata,
@@ -199,6 +241,23 @@ def maxabs_norm(
     )


+@singledispatch
+def _robust_scale_norm_function(arr, **kwargs):
+    raise NotImplementedError(f"robust_scale_norm does not support data to be of type {type(arr)}")
+
+
+@_robust_scale_norm_function.register
+def _(arr: np.ndarray, **kwargs):
+    return sklearn_pp.RobustScaler(**kwargs).fit_transform
+
+
+if DASK_AVAILABLE:
+
+    @_robust_scale_norm_function.register
+    def _(arr: da.Array, **kwargs):
+        return daskml_pp.RobustScaler(**kwargs).fit_transform
+
+
 def robust_scale_norm(
     adata: AnnData,
     vars: str | Sequence[str] | None = None,
@@ -229,10 +288,8 @@ def robust_scale_norm(
     >>> adata = ep.dt.mimic_2(encoded=True)
     >>> adata_norm = ep.pp.robust_scale_norm(adata, copy=True)
     """
-    if is_dask_array(adata.X):
-        scale_func = daskml_pp.RobustScaler(**kwargs).fit_transform
-    else:
-        scale_func = sklearn_pp.RobustScaler(**kwargs).fit_transform
+
+    scale_func = _robust_scale_norm_function(adata.X, **kwargs)

     return _scale_func_group(
         adata=adata,
@@ -244,6 +301,23 @@ def robust_scale_norm(
     )


+@singledispatch
+def _quantile_norm_function(arr):
+    raise NotImplementedError(f"quantile_norm does not support data to be of type {type(arr)}")
+
+
+@_quantile_norm_function.register
+def _(arr: np.ndarray, **kwargs):
+    return sklearn_pp.QuantileTransformer(**kwargs).fit_transform
+
+
+if DASK_AVAILABLE:
+
+    @_quantile_norm_function.register
+    def _(arr: da.Array, **kwargs):
+        return daskml_pp.QuantileTransformer(**kwargs).fit_transform
+
+
 def quantile_norm(
     adata: AnnData,
     vars: str | Sequence[str] | None = None,
@@ -273,10 +347,8 @@ def quantile_norm(
     >>> adata = ep.dt.mimic_2(encoded=True)
     >>> adata_norm = ep.pp.quantile_norm(adata, copy=True)
     """
-    if is_dask_array(adata.X):
-        scale_func = daskml_pp.QuantileTransformer(**kwargs).fit_transform
-    else:
-        scale_func = sklearn_pp.QuantileTransformer(**kwargs).fit_transform
+
+    scale_func = _quantile_norm_function(adata.X, **kwargs)

     return _scale_func_group(
         adata=adata,
@@ -288,6 +360,16 @@ def quantile_norm(
     )


+@singledispatch
+def _power_norm_function(arr, **kwargs):
+    raise NotImplementedError(f"power_norm does not support data to be of type {type(arr)}")
+
+
+@_power_norm_function.register
+def _(arr: np.ndarray, **kwargs):
+    return sklearn_pp.PowerTransformer(**kwargs).fit_transform
+
+
 def power_norm(
     adata: AnnData,
     vars: str | Sequence[str] | None = None,
@@ -317,10 +399,8 @@ def power_norm(
     >>> adata = ep.dt.mimic_2(encoded=True)
     >>> adata_norm = ep.pp.power_norm(adata, copy=True)
     """
-    if is_dask_array(adata.X):
-        raise NotImplementedError("dask-ml has no PowerTransformer, this is only available in scikit-learn")
-    else:
-        scale_func = sklearn_pp.PowerTransformer(**kwargs).fit_transform
+
+    scale_func = _power_norm_function(adata.X, **kwargs)

     return _scale_func_group(
         adata=adata,
What is your take on such a flag? Could it be worth introducing dask as a dependency now that we try to introduce it systematically?
anndata[dask], I guess, to be compatible, but fine with me.
Yes, good point. If dask becomes a dependency, then as anndata[dask].
Wonder whether we should have that flag here, though, rather than somewhere more global. Otherwise we need that check often. If we check globally, though, we might import dask when it's not needed, which is a performance penalty.
Depends on how slow dask is to import. This varies wildly depending on the project. E.g. AnnData is a little slow because of pandas. In scanpy we lazy-import some sklearn subpackages because that would otherwise have a huge impact, and without it we're only slower than anndata because of sklearn.utils and numba.
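To make that trade-off concrete, a sketch of a centralized, lazily evaluated availability check that could live in ehrapy._compat (the helper name and placement are assumptions, not something this PR adds):

    from functools import lru_cache
    from importlib.util import find_spec


    @lru_cache(maxsize=None)
    def dask_ml_available() -> bool:
        # find_spec only consults the import machinery's metadata; it does not
        # execute dask_ml, so calling this at import time stays cheap.
        return find_spec("dask_ml") is not None

The heavy import dask.array / import dask_ml.preprocessing could then be deferred to the point where a dask-backed array is actually handled, so environments that have dask installed but never use it would not pay the import cost when ehrapy is imported.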