Skip to content

Commit

Permalink
Fix it
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Aug 29, 2024
1 parent d79beea commit 7e751fa
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 33 deletions.
5 changes: 1 addition & 4 deletions src/anndata2ri/_r2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import _conv_name
from ._conv import converter, mat_rpy2py
from ._rpy2_ext import importr
from ._rpy2_ext import R_INT_BYTES, importr
from .scipy2ri import supported_r_matrix_classes
from .scipy2ri._r2py import rmat_to_spmat

Expand All @@ -23,9 +23,6 @@
from scipy.sparse import spmatrix


R_INT_BYTES = 4


@converter.rpy2py.register(SexpS4)
def rpy2py_s4(obj: SexpS4) -> pd.DataFrame | AnnData | None:
"""Convert known S4 class instance to Python object.
Expand Down
9 changes: 6 additions & 3 deletions src/anndata2ri/_rpy2_ext.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from __future__ import annotations

from functools import lru_cache
from functools import cache

from rpy2.robjects import Environment, packages


@lru_cache
R_INT_BYTES = 4


@cache
def importr(name: str) -> packages.Package:
return packages.importr(name)


@lru_cache
@cache
def data(package: str, name: str | None = None) -> packages.PackageData | Environment:
if name is None:
return packages.data(importr(package))
Expand Down
42 changes: 21 additions & 21 deletions src/anndata2ri/scipy2ri/_py2r.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

from functools import lru_cache, wraps
from functools import cache, wraps
from importlib.resources import files
from typing import TYPE_CHECKING

import numpy as np
from rpy2.robjects import default_converter, numpy2ri
from rpy2.robjects.conversion import localconverter
from rpy2.robjects.packages import Package, SignatureTranslatedAnonymousPackage
from rpy2.robjects.packages import InstalledSTPackage, SignatureTranslatedAnonymousPackage
from scipy import sparse

from anndata2ri._rpy2_ext import importr
from anndata2ri._rpy2_ext import R_INT_BYTES, importr

from ._conv import converter

Expand All @@ -21,21 +21,26 @@
from rpy2.rinterface import Sexp


matrix: SignatureTranslatedAnonymousPackage | None = None
base: Package | None = None
@cache
def baseenv() -> InstalledSTPackage:
return importr('base')


@lru_cache
def get_r_code() -> str:
return files('anndata2ri').joinpath('scipy2ri', '_py2r_helpers.r').read_text()
@cache
def matrixenv() -> SignatureTranslatedAnonymousPackage:
importr('Matrix') # make class available
r_code = files('anndata2ri').joinpath('scipy2ri', '_py2r_helpers.r').read_text()
return SignatureTranslatedAnonymousPackage(r_code, 'matrix')


def get_type_conv(dtype: np.dtype) -> Callable[[np.ndarray], Sexp]:
global base # noqa: PLW0603
if base is None:
base = importr('base')
base = baseenv()
if np.issubdtype(dtype, np.floating):
return base.as_double
if np.issubdtype(dtype, np.integer):
if dtype.itemsize <= R_INT_BYTES:
return base.as_integer
return base.as_numeric # maybe uses R_xlen_t?
if np.issubdtype(dtype, np.bool_):
return base.as_logical
msg = f'Unknown dtype {dtype!r} cannot be converted to ?gRMatrix.'
Expand All @@ -47,12 +52,7 @@ def py2r_context(f: Callable[[sparse.spmatrix], Sexp]) -> Callable[[sparse.spmat

@wraps(f)
def wrapper(obj: sparse.spmatrix) -> Sexp:
global matrix # noqa: PLW0603
if matrix is None:
importr('Matrix') # make class available
r_code = get_r_code()
matrix = SignatureTranslatedAnonymousPackage(r_code, 'matrix')

matrixenv() # make Matrix class available
return f(obj)

return wrapper
Expand All @@ -64,7 +64,7 @@ def csc_to_rmat(csc: sparse.csc_matrix) -> Sexp:
csc.sort_indices()
conv_data = get_type_conv(csc.dtype)
with localconverter(default_converter + numpy2ri.converter):
return matrix.from_csc(i=csc.indices, p=csc.indptr, x=csc.data, dims=list(csc.shape), conv_data=conv_data)
return matrixenv().from_csc(i=csc.indices, p=csc.indptr, x=csc.data, dims=list(csc.shape), conv_data=conv_data)


@converter.py2rpy.register(sparse.csr_matrix)
Expand All @@ -73,7 +73,7 @@ def csr_to_rmat(csr: sparse.csr_matrix) -> Sexp:
csr.sort_indices()
conv_data = get_type_conv(csr.dtype)
with localconverter(default_converter + numpy2ri.converter):
return matrix.from_csr(
return matrixenv().from_csr(
j=csr.indices,
p=csr.indptr,
x=csr.data,
Expand All @@ -87,7 +87,7 @@ def csr_to_rmat(csr: sparse.csr_matrix) -> Sexp:
def coo_to_rmat(coo: sparse.coo_matrix) -> Sexp:
conv_data = get_type_conv(coo.dtype)
with localconverter(default_converter + numpy2ri.converter):
return matrix.from_coo(
return matrixenv().from_coo(
i=coo.row,
j=coo.col,
x=coo.data,
Expand All @@ -107,7 +107,7 @@ def dia_to_rmat(dia: sparse.dia_matrix) -> Sexp:
)
raise ValueError(msg)
with localconverter(default_converter + numpy2ri.converter):
return matrix.from_dia(
return matrixenv().from_dia(
n=dia.shape[0],
x=dia.data,
conv_data=conv_data,
Expand Down
8 changes: 3 additions & 5 deletions tests/test_py2rpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,9 @@ def test_simple(


def krumsiek() -> AnnData:
with (
pytest.warns(UserWarning, match=r'Duplicated obs_names'),
pytest.warns(UserWarning, match=r'Observation names are not unique'),
):
return sc.datasets.krumsiek11()
adata = sc.datasets.krumsiek11()
adata.obs_names_make_unique()
return adata


def check_empty(_: Sexp) -> None:
Expand Down

0 comments on commit 7e751fa

Please sign in to comment.