Skip to content

Commit

Permalink
Rename is_sparse_array -> is_pydata_sparse.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed May 15, 2024
1 parent b92a35c commit 7ebc3c0
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def is_numpy_array(x):
is_torch_array
is_dask_array
is_jax_array
is_sparse_array
is_pydata_sparse
"""
# Avoid importing NumPy if it isn't already
if 'numpy' not in sys.modules:
Expand Down Expand Up @@ -80,7 +80,7 @@ def is_cupy_array(x):
is_torch_array
is_dask_array
is_jax_array
is_sparse_array
is_pydata_sparse
"""
# Avoid importing NumPy if it isn't already
if 'cupy' not in sys.modules:
Expand All @@ -107,7 +107,7 @@ def is_torch_array(x):
is_cupy_array
is_dask_array
is_jax_array
is_sparse_array
is_pydata_sparse
"""
# Avoid importing torch if it isn't already
if 'torch' not in sys.modules:
Expand All @@ -134,7 +134,7 @@ def is_dask_array(x):
is_cupy_array
is_torch_array
is_jax_array
is_sparse_array
is_pydata_sparse
"""
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
Expand All @@ -161,7 +161,7 @@ def is_jax_array(x):
is_cupy_array
is_torch_array
is_dask_array
is_sparse_array
is_pydata_sparse
"""
# Avoid importing jax if it isn't already
if 'jax' not in sys.modules:
Expand All @@ -172,7 +172,7 @@ def is_jax_array(x):
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_sparse_array(x) -> bool:
def is_pydata_sparse(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -219,7 +219,7 @@ def is_array_api_obj(x):
or is_torch_array(x) \
or is_dask_array(x) \
or is_jax_array(x) \
or is_sparse_array(x) \
or is_pydata_sparse(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
Expand Down Expand Up @@ -288,7 +288,7 @@ def your_function(x, y):
is_torch_array
is_dask_array
is_jax_array
is_sparse_array
is_pydata_sparse
"""
if use_compat not in [None, True, False]:
Expand Down Expand Up @@ -348,7 +348,7 @@ def your_function(x, y):
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif is_sparse_array(x):
elif is_pydata_sparse(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
Expand Down Expand Up @@ -451,7 +451,7 @@ def device(x: Array, /) -> Device:
return x.device()
else:
return x.device
elif is_sparse_array(x):
elif is_pydata_sparse(x):
# `sparse` will gain `.device`, so check for this first.
x_device = getattr(x, 'device', None)
if x_device is not None:
Expand Down Expand Up @@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# This import adds to_device to x
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
elif is_sparse_array(x) and device == _device(x):
elif is_pydata_sparse(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
return x
Expand Down Expand Up @@ -613,7 +613,7 @@ def size(x):
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"is_sparse_array",
"is_pydata_sparse",
"size",
"to_device",
]
Expand Down

0 comments on commit 7ebc3c0

Please sign in to comment.