Skip to content

Commit

Permalink
Merge pull request #154 from adityagoel4512/add-ndonnx
Browse files Browse the repository at this point in the history
Add ndonnx
  • Loading branch information
asmeurer authored Jun 25, 2024
2 parents ac15c52 + 1ac19a4 commit 38ec1d4
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ jobs:
else
PIP_EXTRA='numpy==1.26.*'
fi
if [ "${{ matrix.python-version }}" == "3.9" ]; then
sed -i '/^ndonnx/d' requirements-dev.txt
fi
python -m pip install -r requirements-dev.txt $PIP_EXTRA
- name: Run Tests
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
for other array libraries, or if you encounter any issues, please [open an
issue](https://github.com/data-apis/array-api-compat/issues).
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
support for other array libraries, or if you encounter any issues, please [open
an issue](https://github.com/data-apis/array-api-compat/issues).

See the documentation for more details https://data-apis.org/array-api-compat/
36 changes: 35 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def is_numpy_array(x):
is_array_api_obj
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
Expand Down Expand Up @@ -78,11 +79,12 @@ def is_cupy_array(x):
is_array_api_obj
is_numpy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
# Avoid importing NumPy if it isn't already
# Avoid importing CuPy if it isn't already
if 'cupy' not in sys.modules:
return False

Expand Down Expand Up @@ -118,6 +120,33 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.
This function does not import ndonnx if it has not already been imported
and is therefore cheap to use.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
# Avoid importing torch if it isn't already
if 'ndonnx' not in sys.modules:
return False

import ndonnx as ndx

return isinstance(x, ndx.Array)

def is_dask_array(x):
"""
Return True if `x` is a dask.array Array.
Expand All @@ -133,6 +162,7 @@ def is_dask_array(x):
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_jax_array
is_pydata_sparse_array
"""
Expand Down Expand Up @@ -160,6 +190,7 @@ def is_jax_array(x):
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_pydata_sparse_array
"""
Expand Down Expand Up @@ -188,6 +219,7 @@ def is_pydata_sparse_array(x) -> bool:
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
"""
Expand All @@ -211,6 +243,7 @@ def is_array_api_obj(x):
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
"""
Expand Down Expand Up @@ -613,6 +646,7 @@ def size(x):
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"is_ndonnx_array",
"is_pydata_sparse_array",
"size",
"to_device",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ numpy
pytest
torch
sparse >=0.15.1
ndonnx

0 comments on commit 38ec1d4

Please sign in to comment.