Skip to content

Commit

Permalink
Added approximate KNN backend (#791)
Browse files Browse the repository at this point in the history
* Added approximate KNN backend

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com>

* Minor doc and annotation improvement after @eroell review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add pynndescent intersphinx mapping

* fix link, precise description of 3 argument options

* remove unwanted space

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com>
Co-authored-by: Eljas Roellin <eljas.roellin@ikmail.com>
  • Loading branch information
4 people authored Nov 14, 2024
1 parent 4a9e05d commit 5336a2e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"pandas": ("https://pandas.pydata.org/docs/", None),
"python": ("https://docs.python.org/3", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
"pynndescent": ("https://pynndescent.readthedocs.io/en/latest/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"torch": ("https://pytorch.org/docs/master/", None),
"scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
Expand Down
79 changes: 51 additions & 28 deletions ehrapy/preprocessing/_scanpy_pp_api.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
from collections.abc import Collection, Mapping, Sequence
from __future__ import annotations

from types import MappingProxyType
from typing import Any, Callable, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Union

import numpy as np
import scanpy as sc
from anndata import AnnData
from scipy.sparse import spmatrix

if TYPE_CHECKING:
from collections.abc import Collection, Mapping, Sequence

from anndata import AnnData
from scanpy.neighbors import KnnTransformerLike
from scipy.sparse import spmatrix

from ehrapy.preprocessing._types import KnownTransformer

AnyRandom = Union[int, np.random.RandomState, None]


def pca(
data: Union[AnnData, np.ndarray, spmatrix],
n_comps: Optional[int] = None,
zero_center: Optional[bool] = True,
data: AnnData | np.ndarray | spmatrix,
n_comps: int | None = None,
zero_center: bool | None = True,
svd_solver: str = "arpack",
random_state: AnyRandom = 0,
return_info: bool = False,
dtype: str = "float32",
copy: bool = False,
chunked: bool = False,
chunk_size: Optional[int] = None,
) -> Union[AnnData, np.ndarray, spmatrix]: # pragma: no cover
chunk_size: int | None = None,
) -> AnnData | np.ndarray | spmatrix | None: # pragma: no cover
"""Computes a principal component analysis.
Computes PCA coordinates, loadings and variance decomposition. Uses the implementation of *scikit-learn*.
Expand Down Expand Up @@ -91,17 +99,17 @@ def pca(

def regress_out(
adata: AnnData,
keys: Union[str, Sequence[str]],
n_jobs: Optional[int] = None,
keys: str | Sequence[str],
n_jobs: int | None = None,
copy: bool = False,
) -> Optional[AnnData]: # pragma: no cover
) -> AnnData | None: # pragma: no cover
"""Regress out (mostly) unwanted sources of variation.
Uses simple linear regression. This is inspired by Seurat's `regressOut` function in R [Satija15].
Note that this function tends to overcorrect in certain circumstances.
Args:
adata: :class:`~anndata.AnnData` object object containing all observations.
adata: :class:`~anndata.AnnData` object containing all observations.
keys: Keys for observation annotation on which to regress on.
n_jobs: Number of jobs for parallel computation. `None` means using :attr:`scanpy._settings.ScanpyConfig.n_jobs`.
copy: Determines whether a copy of `adata` is returned.
Expand All @@ -113,12 +121,12 @@ def regress_out(


def subsample(
data: Union[AnnData, np.ndarray, spmatrix],
fraction: Optional[float] = None,
n_obs: Optional[int] = None,
data: AnnData | np.ndarray | spmatrix,
fraction: float | None = None,
n_obs: int | None = None,
random_state: AnyRandom = 0,
copy: bool = False,
) -> Optional[AnnData]: # pragma: no cover
) -> AnnData | None: # pragma: no cover
"""Subsample to a fraction of the number of observations.
Args:
Expand All @@ -138,9 +146,9 @@ def subsample(
def combat(
adata: AnnData,
key: str = "batch",
covariates: Optional[Collection[str]] = None,
covariates: Collection[str] | None = None,
inplace: bool = True,
) -> Union[AnnData, np.ndarray, None]: # pragma: no cover
) -> AnnData | np.ndarray | None: # pragma: no cover
"""ComBat function for batch effect correction [Johnson07]_ [Leek12]_ [Pedersen12]_.
Corrects for batch effects by fitting linear models, gains statistical power via an EB framework where information is borrowed across features.
Expand All @@ -149,7 +157,7 @@ def combat(
.. _combat.py: https://github.com/brentp/combat.py
Args:
adata: :class:`~anndata.AnnData` object object containing all observations.
adata: :class:`~anndata.AnnData` object containing all observations.
key: Key to a categorical annotation from :attr:`~anndata.AnnData.obs` that will be used for batch effect removal.
covariates: Additional covariates besides the batch variable such as adjustment variables or biological condition.
This parameter refers to the design matrix `X` in Equation 2.1 in [Johnson07]_ and to the `mod` argument in
Expand All @@ -163,7 +171,7 @@ def combat(
return sc.pp.combat(adata=adata, key=key, covariates=covariates, inplace=inplace)


_Method = Literal["umap", "gauss", "rapids"]
_Method = Literal["umap", "gauss"]
_MetricFn = Callable[[np.ndarray, np.ndarray], float]
_MetricSparseCapable = Literal["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"]
_MetricScipySpatial = Literal[
Expand Down Expand Up @@ -191,16 +199,17 @@ def combat(
def neighbors(
adata: AnnData,
n_neighbors: int = 15,
n_pcs: Optional[int] = None,
use_rep: Optional[str] = None,
n_pcs: int | None = None,
use_rep: str | None = None,
knn: bool = True,
random_state: AnyRandom = 0,
method: Optional[_Method] = "umap",
metric: Union[_Metric, _MetricFn] = "euclidean",
method: _Method = "umap",
transformer: KnnTransformerLike | KnownTransformer | None = None,
metric: _Metric | _MetricFn = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
key_added: Optional[str] = None,
key_added: str | None = None,
copy: bool = False,
) -> Optional[AnnData]: # pragma: no cover
) -> AnnData | None: # pragma: no cover
"""Compute a neighborhood graph of observations [McInnes18]_.
The neighbor search efficiency of this heavily relies on UMAP [McInnes18]_,
Expand All @@ -209,7 +218,7 @@ def neighbors(
connectivities are computed according to [Coifman05]_, in the adaption of [Haghverdi16]_.
Args:
adata: :class:`~anndata.AnnData` object object containing all observations.
adata: :class:`~anndata.AnnData` object containing all observations.
n_neighbors: The size of local neighborhood (in terms of number of neighboring data points) used for manifold approximation.
Larger values result in more global views of the manifold, while smaller values result in more local data being preserved.
In general values should be in the range 2 to 100. If `knn` is `True`, number of nearest neighbors to be searched.
Expand All @@ -225,6 +234,19 @@ def neighbors(
method: Use 'umap' [McInnes18]_ or 'gauss' (Gauss kernel following [Coifman05]_ with adaptive width [Haghverdi16]_) for computing connectivities.
Use 'rapids' for the RAPIDS implementation of UMAP (experimental, GPU only).
metric: A known metric’s name or a callable that returns a distance.
transformer: Approximate kNN search implementation. Follows the API of
:class:`~sklearn.neighbors.KNeighborsTransformer`.
See scanpy's `knn-transformers tutorial <https://scanpy.readthedocs.io/en/latest/how-to/knn-transformers.html>`_ for more details. This tutorial is also valid for ehrapy's `neighbors` function.
Next to the advanced options from the knn-transformers tutorial, this argument accepts the following basic options:
`None` (the default)
Behavior depends on data size.
For small data, uses :class:`~sklearn.neighbors.KNeighborsTransformer` with algorithm="brute" for exact kNN, otherwise uses
:class:`~pynndescent.pynndescent_.PyNNDescentTransformer` for approximate kNN.
`'pynndescent'`
Uses :class:`~pynndescent.pynndescent_.PyNNDescentTransformer` for approximate kNN.
`'sklearn'`
Uses :class:`~sklearn.neighbors.KNeighborsTransformer` with algorithm="brute" for exact kNN.
metric_kwds: Options for the metric.
key_added: If not specified, the neighbors data is stored in .uns['neighbors'],
distances and connectivities are stored in .obsp['distances'] and .obsp['connectivities'] respectively.
Expand All @@ -250,6 +272,7 @@ def neighbors(
knn=knn,
random_state=random_state,
method=method,
transformer=transformer,
metric=metric,
metric_kwds=metric_kwds,
key_added=key_added,
Expand Down
5 changes: 5 additions & 0 deletions ehrapy/preprocessing/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from typing import Literal

KnownTransformer = Literal["pynndescent", "sklearn"]

0 comments on commit 5336a2e

Please sign in to comment.